1127 update to latest
139
dllm/.gitignore
vendored
Normal file
@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
applications/DeepSpeed-Chat/data
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Others
/.vscode/
/tmp/
/data/
/wandb/
/logs/
/models*/
4
dllm/.gitmodules
vendored
Normal file
@ -0,0 +1,4 @@
[submodule "lm-evaluation-harness"]
	path = lm-evaluation-harness
	url = https://github.com/ZHZisZZ/lm-evaluation-harness
	branch = dllm
21
dllm/LICENSE
Normal file
@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Zhanhui Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
283
dllm/README.md
Normal file
@ -0,0 +1,283 @@
<h1 align="center">dLLM</h1>

<p align="center">
Simple Diffusion Language Modeling
</p>

<p align="center">
<img
src="assets/logo.gif"
alt="dLLM logo">
</p>


## Overview
**dLLM** is a library that unifies the training and evaluation of **diffusion language models**, bringing transparency and reproducibility to the entire development pipeline:

<!-- and [RND1](https://www.radicalnumerics.ai/assets/rnd1_report.pdf) -->

- dLLM provides scalable training pipelines (inspired by the [`transformers`](https://github.com/huggingface/transformers/blob/main/src/transformers) [`Trainer`](https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py)), with support for [LoRA](https://github.com/huggingface/peft), [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), [FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/), and beyond.

- dLLM provides unified evaluation pipelines (inspired by [`lm-evaluation-harness`](https://github.com/EleutherAI/lm-evaluation-harness)) that abstract away inference details and make customization simple.

- Built on these components, dLLM provides minimal **pretraining / finetuning / evaluation** recipes for open-weight models (e.g., [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487)), as well as implementations of training algorithms (e.g., [Edit Flows](https://arxiv.org/abs/2506.09018)).

<!-- > [!NOTE]
> This repository is primarily for educational purposes and does not aim for 100% exact reproduction of official models (which is impossible). We hope it serves as a helpful reference for the community — contributions and improvements are always welcome! -->


## News
**[2025/11]** We released a collection of BERTs finetuned for instruction-following: [`ModernBERT-{large,base}-chat-v0`](https://huggingface.co/collections/dllm-collection/bert-chat). This proof-of-concept shows that BERT’s internal knowledge can be leveraged for generative tasks via masked instruction tuning. See the [BERT Chat Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg) for detailed recipes, experimental results, and lessons learned; see [`examples/bert`](/examples/bert) for training / inference / evaluation instructions.


## Table of Contents
- [Features](#features)
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)
- [Citation](#citation)


## Features
<!-- - [`examples/rnd`](/examples/rnd): (WIP) Finetuning the open-weight [RND1-Base](https://www.radicalnumerics.ai/assets/rnd1_report.pdf). -->
- [`examples/llada`](/examples/llada): Pretraining, finetuning and evaluating [LLaDA](https://arxiv.org/abs/2502.09992) / [LLaDA-MoE](https://arxiv.org/abs/2509.24389).
- [`examples/dream`](/examples/dream): Pretraining, finetuning and evaluating [Dream](https://arxiv.org/abs/2508.15487).
- [`examples/bert`](/examples/bert): Finetuning any [BERT](https://arxiv.org/abs/1810.04805) into a lightweight chatbot.
<details>
<summary>🎬 Click to show BERT Chat Demo</summary>

<p align="center">
<img src="/examples/bert/assets/chat.gif" alt="chat" width="80%">
</p>
<p align="center">
<em>
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
</em>
</p>
</details>
- [`examples/editflow`](/examples/editflow): Educational reference for training [EditFlow](https://arxiv.org/abs/2506.09018) models, demonstrating how to extend existing dLLMs (e.g., LLaDA, Dream, BERT Chat) with *edit operations* (insertion, deletion, and substitution) and how to pretrain or finetune EditFlow models from scratch on public data.

<details>
<summary>🎬 Click to show EditFlow Demo</summary>

<p align="center">
<img src="/examples/editflow/assets/all.gif" alt="EditFlow demo" width="100%">
</p>
<p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p>

</details>
- More upcoming.


## Setup
### Installation
```bash
# create and activate conda environment
conda create -n dllm python=3.10 -y
conda activate dllm

# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
conda install cuda=12.4 -c nvidia
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
    --index-url https://download.pytorch.org/whl/cu124

# install the dllm package
pip install -e .
```
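To sanity-check the installation, the package and its core submodules should import cleanly. A quick, optional check (assuming the steps above completed without errors):

```bash
# optional: verify the editable install and that torch sees a GPU
python -c "import dllm, torch; print(dllm.core.trainers.MDLMTrainer.__name__, torch.cuda.is_available())"
```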
### (optional) Evaluation setup

```bash
# initialize the `lm-evaluation-harness` submodule
git submodule update --init --recursive

# install the submodule in editable mode with IFEval & Math dependencies
pip install -e "lm-evaluation-harness[ifeval,math]"
```

### (optional) Slurm setup
For [Slurm](https://slurm.schedmd.com/) users, update [`scripts/train.slurm.sh`](/scripts/train.slurm.sh) for your cluster:
```diff
- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
- #SBATCH --quotatype=spot        # Note: adjust this for your cluster
+ #SBATCH --partition=YOUR_PARTITION
+ #SBATCH --quotatype=YOUR_QUOTATYPE
```
Next, create a directory for your job logs:
```shell
mkdir logs
```
This folder will store the log files generated by your sbatch jobs.


## Files overview
```
# modules for training / sampling
dllm
├── core                  # Core reusable modules shared across `dllm/pipelines`
│   ├── generation
│   ├── schedulers
│   └── trainers
├── data
├── pipelines             # Application-specific training & inference pipelines
│   ├── bert
│   ├── dream
│   ├── editflow
│   └── llada
│       ├── models        # Model architecture and configs
│       ├── generator.py  # Generation utilities
│       ├── trainer.py    # Core training logic
│       └── eval.py       # Evaluation entry point
├── tools
└── utils

# entry points for training / sampling
examples
├── bert
├── dream
├── editflow
└── llada
    ├── chat.py           # Interactive inference example
    ├── generate.py       # Inference example
    ├── pt.py             # Pretraining example
    ├── README.md         # Documentation (you are here)
    ├── sft.py            # Supervised finetuning example
    └── eval.sh           # Evaluation script
```


## Training

A typical training entry script (for example, [`examples/llada/sft.py`](/examples/llada/sft.py)) looks like this:
```python
import transformers

import dllm

# `parser` is built with transformers.HfArgumentParser (see the sketch below)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
dataset = "..."

# ----- Training ---------------------------------------------------------------
trainer = dllm.core.trainers.MDLMTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=training_args,
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer,
        return_tensors="pt",
        padding=True,
        label_pad_token_id=tokenizer.pad_token_id,
    ),
)
trainer.train()
```
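The snippet above elides how `parser` is constructed. `parse_args_into_dataclasses()` comes from `transformers.HfArgumentParser`, so a minimal sketch looks like the following; the `ModelArguments` / `DataArguments` dataclasses here are illustrative stand-ins, not necessarily the exact ones defined in `examples/llada/sft.py`:

```python
from dataclasses import dataclass, field

import transformers


@dataclass
class ModelArguments:  # illustrative stand-in for the script's own dataclass
    model_name_or_path: str = field(default="GSAI-ML/LLaDA-8B-Base")


@dataclass
class DataArguments:  # illustrative stand-in
    dataset_args: str = field(default="allenai/tulu-3-sft-mixture")


# HfArgumentParser maps CLI flags onto the three dataclasses
parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, transformers.TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
```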

You can launch a training job locally with `accelerate`, or submit it to a [Slurm](https://slurm.schedmd.com/) cluster using `sbatch`.
```shell
# Run locally (ZeRO-2 on 8 GPUs with 4-bit quantization and LoRA)
accelerate launch \
    --config_file scripts/accelerate_configs/zero2.yaml \
    examples/llada/sft.py \
    --num_train_epochs 4 \
    --load_in_4bit True --lora True
```
```shell
# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs)
sbatch --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --num_train_epochs 4

# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs)
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --num_train_epochs 4
```
See [Features](#features) for specific training recipes.


> Here are some useful tips for training (a combined example follows this list):
> 1. Use a subset of data:
>    `--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"`
> 2. Concatenate datasets:
>    `--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk"`
> 3. Train with LoRA and 4-bit quantization:
>    `--load_in_4bit True --lora True`
> 4. Train with different distributed training methods:
>    `--accelerate_config "ddp,zero-{1,2,3},fsdp"`
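
For instance, tips 1 and 3 compose with the local `accelerate` launch shown above (a sketch, assuming `examples/llada/sft.py` accepts `--dataset_args` as described in the tips):

```shell
# sketch: finetune on a 10k/1k subset of tulu-3 with LoRA + 4-bit quantization (ZeRO-2, 8 GPUs)
accelerate launch \
    --config_file scripts/accelerate_configs/zero2.yaml \
    examples/llada/sft.py \
    --dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
    --num_train_epochs 4 \
    --load_in_4bit True --lora True
```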

## Inference

We provide unified [generators](/dllm/core/generation/generator.py) that abstract away inference details.
A typical inference entry script (for example, [`examples/llada/generate.py`](/examples/llada/generate.py)) looks like this:
```python
import dllm
from dllm import llada

model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
# for other models, swap in the matching generator and keep everything else unchanged
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)

messages = [
    [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
    [{"role": "user", "content": "Please write an educational python function."}],
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)

outputs = generator.generate(inputs, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)  # see note below
```
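`decode_trim` here is a small helper used by the example script; conceptually it strips each prompt prefix from the generated ids and decodes the remainder. A minimal sketch of that idea (not the exact implementation in `examples/llada/generate.py`):

```python
def decode_trim(tokenizer, sequences: list[list[int]], prompts: list[list[int]]) -> list[str]:
    """Drop each prompt prefix from its generated sequence, then decode the completion."""
    return [
        tokenizer.decode(seq[len(prompt):], skip_special_tokens=True)
        for seq, prompt in zip(sequences, prompts)
    ]
```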

You can also try the interactive chat script (for example, [`examples/llada/chat.py`](/examples/llada/chat.py)) for visualized multi-turn dialogue:

<p align="center">
<img src="/assets/chat.gif" alt="chat" width="80%">
</p>
<!-- <p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p> -->

## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.

For example, to evaluate [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro), run:
```shell
accelerate launch --num_processes 4 \
    dllm/pipelines/llada/eval.py \
    --tasks "mmlu_pro" \
    --model "llada" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
```

We also provide scripts to automatically evaluate [LLaDA](https://arxiv.org/abs/2502.09992), [Dream](https://arxiv.org/abs/2508.15487), and [BERT-Chat](https://huggingface.co/collections/dllm-collection/bert-chat) on all benchmarks.
For example, you can launch [`examples/llada/eval.sh`](/examples/llada/eval.sh) directly using the following commands:
```shell
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False
```


## Citation
```
@misc{dllm,
  author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
  title = {dLLM: Simple Diffusion Language Modeling},
  year = {2025},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}
```
BIN
dllm/assets/JetBrainsMono-VariableFont_wght.ttf
Normal file
BIN
dllm/assets/chat.gif
Normal file
|
After Width: | Height: | Size: 5.5 MiB |
BIN
dllm/assets/logo.gif
Normal file
|
After Width: | Height: | Size: 956 KiB |
BIN
dllm/assets/logo.png
Normal file
|
After Width: | Height: | Size: 5.0 KiB |
119
dllm/assets/logo.py
Normal file
@ -0,0 +1,119 @@
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os
|
||||
|
||||
# ---------- Configuration (smaller size) ----------
|
||||
W, H = 480, 210 # lower resolution
|
||||
TOTAL_DURATION = 3.0
|
||||
FPS = 15 # lower fps
|
||||
TEXT = "dLLM"
|
||||
# TEXT_COLOR = (235, 235, 235)
|
||||
TEXT_COLOR = (0, 0, 0)
|
||||
OUTPUT = "logo.gif"
|
||||
LAST_FRAME_PNG = "logo.png"
|
||||
|
||||
DIFFUSION_PORTION = 0.3 # fewer diffusion frames
|
||||
SEED = 8
|
||||
|
||||
|
||||
# ---------- Auto font size ----------
|
||||
def load_font_auto_size(text, w, h, target_width_ratio=0.95, target_height_ratio=0.95):
|
||||
lo, hi = 10, 2000
|
||||
best_font, best_size = None, lo
|
||||
while lo <= hi:
|
||||
mid = (lo + hi) // 2
|
||||
try:
|
||||
font = ImageFont.truetype(
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=mid
|
||||
)
|
||||
except Exception:  # fall back to the default font if the TTF is unavailable
|
||||
font = ImageFont.load_default()
|
||||
dummy = Image.new("L", (w, h), 0)
|
||||
d = ImageDraw.Draw(dummy)
|
||||
bbox = d.textbbox((0, 0), text, font=font)
|
||||
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
|
||||
width_ok = tw <= w * target_width_ratio
|
||||
height_ok = th <= h * target_height_ratio
|
||||
|
||||
if width_ok and height_ok:
|
||||
best_font, best_size = font, mid
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid - 1
|
||||
|
||||
return best_font if best_font is not None else font
|
||||
|
||||
|
||||
# ---------- Text rendering ----------
|
||||
def render_text_mask(w, h, text, font):
|
||||
img = Image.new("L", (w, h), 0)
|
||||
d = ImageDraw.Draw(img)
|
||||
bbox = d.textbbox((0, 0), text, font=font)
|
||||
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
x = (w - tw) // 2 - bbox[0]
|
||||
y = (h - th) // 2 - bbox[1]
|
||||
d.text((x, y), text, font=font, fill=255)
|
||||
return np.asarray(img, np.float32) / 255.0
|
||||
|
||||
|
||||
# ---------- Initialization ----------
|
||||
font = load_font_auto_size(TEXT, W, H)
|
||||
mask = render_text_mask(W, H, TEXT, font)
|
||||
|
||||
num_frames = int(TOTAL_DURATION * FPS)
|
||||
diffusion_frames = max(1, int(num_frames * DIFFUSION_PORTION))
|
||||
hold_ms = int((TOTAL_DURATION - diffusion_frames / FPS) * 1000)
|
||||
|
||||
rng = np.random.default_rng(SEED)
|
||||
frames = []
|
||||
|
||||
# ---------- Diffusion stage ----------
|
||||
for i in range(diffusion_frames):
|
||||
t = i / max(1, diffusion_frames - 1)
|
||||
progress = t**0.9
|
||||
noise_sigma = (1.0 - progress) ** 2.2
|
||||
|
||||
noise = rng.standard_normal((H, W, 1)).astype(np.float32)
|
||||
noise_img = 1.0 - noise_sigma * 0.5 * np.abs(noise)
|
||||
np.clip(noise_img, 0.0, 1.0, out=noise_img)
|
||||
|
||||
alpha = progress**2.0
|
||||
alpha_map = (mask * alpha).astype(np.float32)[..., None]
|
||||
|
||||
text_rgb = np.zeros((H, W, 3), dtype=np.float32)
|
||||
for c in range(3):
|
||||
text_rgb[..., c] = (mask > 0).astype(np.float32) * (TEXT_COLOR[c] / 255.0)
|
||||
|
||||
frame = (1.0 - alpha_map) * noise_img + alpha_map * text_rgb
|
||||
frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
|
||||
frames.append(Image.fromarray(frame, mode="RGB"))
|
||||
|
||||
# ---------- Last frame ----------
|
||||
final_frame = frames[-1]
|
||||
|
||||
# ---------- Save last frame as PNG ----------
|
||||
final_frame.save(LAST_FRAME_PNG)
|
||||
print(f"🖼️ Last frame saved as: {LAST_FRAME_PNG}")
|
||||
|
||||
# ---------- Quantization (reduce size) ----------
|
||||
pal_frames = [f.convert("P", palette=Image.ADAPTIVE, colors=64) for f in frames]
|
||||
pal_final = final_frame.convert("P", palette=Image.ADAPTIVE, colors=64)
|
||||
|
||||
# ---------- Save GIF ----------
|
||||
normal_ms = int(1000 / FPS)
|
||||
durations = [normal_ms] * len(pal_frames) + [hold_ms]
|
||||
|
||||
pal_frames[0].save(
|
||||
OUTPUT,
|
||||
save_all=True,
|
||||
append_images=pal_frames[1:] + [pal_final],
|
||||
duration=durations,
|
||||
loop=0,
|
||||
optimize=True,
|
||||
)
|
||||
|
||||
print(f"✅ GIF saved: {OUTPUT}")
|
||||
print(
|
||||
f"Frames (diffusion only): {len(pal_frames)} at {FPS} FPS, final hold {hold_ms} ms, resolution {W}x{H}"
|
||||
)
|
||||
1
dllm/dllm/__init__.py
Normal file
@ -0,0 +1 @@
from . import core, data, pipelines, utils
1
dllm/dllm/core/__init__.py
Normal file
@ -0,0 +1 @@
from dllm.core import trainers, schedulers, generation
1
dllm/dllm/core/generation/__init__.py
Normal file
@ -0,0 +1 @@
from . import generator, visualizer
49
dllm/dllm/core/generation/generator.py
Normal file
@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
from transformers import PreTrainedTokenizer, PreTrainedModel

from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler


@dataclass
class GeneratorOutput:
    sequences: torch.Tensor
    histories: list[torch.Tensor] | None = None


@dataclass
class GeneratorConfig:
    return_dict_in_generate: bool = False


@dataclass
class BaseGenerator(ABC):
    model: PreTrainedModel
    tokenizer: PreTrainedTokenizer
    scheduler: BaseAlphaScheduler | None = None

    def __post_init__(self):
        if self.scheduler is None:
            self.scheduler = LinearAlphaScheduler()

    @abstractmethod
    @torch.no_grad()
    def generate(
        self,
        prompts: list[torch.Tensor | list],
        config: GeneratorConfig | None = None,
        **kwargs,
    ) -> GeneratorOutput:
        raise NotImplementedError

    @abstractmethod
    @torch.no_grad()
    def infill(
        self,
        inputs: list[torch.Tensor | list],
        config: GeneratorConfig | None = None,
        **kwargs,
    ) -> GeneratorOutput:
        raise NotImplementedError
427
dllm/dllm/core/generation/visualizer.py
Normal file
@ -0,0 +1,427 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseVisualizer(ABC):
|
||||
tokenizer: PreTrainedTokenizer
|
||||
|
||||
@abstractmethod
|
||||
def visualize(self, history: list[torch.Tensor], **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoVisualizer(BaseVisualizer):
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
history: list[torch.Tensor],
|
||||
output_path: str = "visualization.gif",
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class TerminalVisualizer(BaseVisualizer):
|
||||
|
||||
# Configuration (adjust as needed)
|
||||
HEADER_SIZE = 3 # Fixed number of lines for the header (0 if show_header is False)
|
||||
PROGRESS_SIZE = 3 # Fixed number of lines for the progress bar
|
||||
PANEL_PADDING_TOP = 1 # Top padding of the Panel (padding=(top, side))
|
||||
PANEL_PADDING_BOTTOM = 1 # Bottom padding of the Panel
|
||||
PANEL_PADDING_SIDE = 1 # Number of characters used for left and right padding
|
||||
PANEL_BORDER = 2 # Number of columns taken by the Panel border (usually 2)
|
||||
MIN_TOTAL_HEIGHT = 10 # Minimum terminal height (in lines)
|
||||
MAX_TOTAL_HEIGHT = 60 # Maximum terminal height to prevent overflowing the terminal
|
||||
DEFAULT_TERM_WIDTH = 120 # Default terminal width (in columns)
|
||||
ansi_escape = re.compile(r"\x1b\[[0-9;]*m") # Regex to match ANSI escape codes
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
|
||||
fps: int = 16,
|
||||
rich: bool = True,
|
||||
title: str = "dllm",
|
||||
max_chars: int = None,
|
||||
every_n_steps: int = 1,
|
||||
show_header: bool = True,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Visualize a masked-diffusion decoding trajectory stored in `history`.
|
||||
If items have batch dimension [B, T], visualize each sequence separately.
|
||||
"""
|
||||
try:
|
||||
# detect batch size
|
||||
first_step = history[0]
|
||||
if first_step.dim() > 1 and first_step.shape[0] > 1:
|
||||
B = first_step.shape[0]
|
||||
for b_idx in range(B):
|
||||
# build per-sequence history
|
||||
seq_history = [step[b_idx].unsqueeze(0) for step in history]
|
||||
self.visualize_one_history(
|
||||
seq_history,
|
||||
fps,
|
||||
rich,
|
||||
title=f"{title} (Batch {b_idx})",
|
||||
max_chars=max_chars,
|
||||
every_n_steps=every_n_steps,
|
||||
show_header=show_header,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
)
|
||||
else:
|
||||
# no batch, just visualize normally
|
||||
self.visualize_one_history(
|
||||
history,
|
||||
fps,
|
||||
rich,
|
||||
title,
|
||||
max_chars,
|
||||
every_n_steps,
|
||||
show_header,
|
||||
skip_special_tokens,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"(Visualization skipped due to error: {e})")
|
||||
|
||||
def visualize_one_history(
|
||||
self,
|
||||
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
|
||||
fps: int = 16,
|
||||
rich: bool = True,
|
||||
title: str = "dllm",
|
||||
max_chars: int = None,
|
||||
every_n_steps: int = 1, # re-render frequency (perf knob)
|
||||
show_header: bool = True,
|
||||
skip_special_tokens: bool = False, # NEW ARGUMENT
|
||||
) -> None:
|
||||
"""
|
||||
Visualize a masked-diffusion decoding trajectory stored in `history`.
|
||||
|
||||
Args:
|
||||
history: Sequence of token tensors for each step. Each item is [T] or [B,T].
|
||||
fps: Frames per second for the live UI (Rich) or sleep cadence for tqdm fallback.
|
||||
title: Header title.
|
||||
max_chars: Cap on rendered characters to keep terminal snappy.
|
||||
every_n_steps: Only redraw text every N steps (progress still updates every step).
|
||||
show_header: Show the magenta header bar (Rich path).
|
||||
skip_special_tokens: Whether to skip special/pad/eos tokens when rendering (default: False).
|
||||
Notes:
|
||||
- Masked positions are detected via `self.tokenizer.mask_token_id`.
|
||||
- Special tokens are determined via `self.tokenizer.all_special_ids`.
|
||||
- All layout, styling, and progress are encapsulated here.
|
||||
"""
|
||||
# --------- imports & env checks ----------
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.text import Text
|
||||
from rich.panel import Panel
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
)
|
||||
from rich.layout import Layout
|
||||
|
||||
_RICH_IMPORTED = True
|
||||
except Exception:
|
||||
_RICH_IMPORTED = False
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
_TQDM_IMPORTED = True
|
||||
except Exception:
|
||||
_TQDM_IMPORTED = False
|
||||
|
||||
if self.tokenizer is None:
|
||||
raise ValueError(
|
||||
"TerminalVisualizer.tokenizer must be set to a valid tokenizer."
|
||||
)
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
specials: set[int] = set(getattr(tokenizer, "all_special_ids", []) or [])
|
||||
self._specials = specials # store for helpers
|
||||
self._mask_token_id: Optional[int] = getattr(tokenizer, "mask_token_id", None)
|
||||
self._pad_token_id: Optional[int] = getattr(tokenizer, "pad_token_id", None)
|
||||
self._eos_token_id: Optional[int] = getattr(tokenizer, "eos_token_id", None)
|
||||
|
||||
# --------- helpers inside class scope ----------
|
||||
# (keep everything inside this class as requested)
|
||||
|
||||
# throttle settings
|
||||
sleep_s = 0.0 if fps <= 0 else 1.0 / float(max(1, fps))
|
||||
total_steps = len(history)
|
||||
every_n_steps = max(1, int(every_n_steps))
|
||||
|
||||
# decode final text up-front (used after render)
|
||||
final_text = self._detok(history[-1], skip_special_tokens=skip_special_tokens)
|
||||
final_text = self._truncate(final_text, max_chars)
|
||||
|
||||
# ------------------ new: estimate height from final_text ------------------
|
||||
import textwrap
|
||||
import shutil
|
||||
|
||||
def strip_ansi(s: str) -> str:
|
||||
return self.ansi_escape.sub("", s) if s else ""
|
||||
|
||||
def estimate_height_from_text(text: str, console_width: int) -> int:
|
||||
"""
|
||||
Estimate how many terminal rows the panel with `text` will need given console_width.
|
||||
Uses class constants for paddings/borders and header/progress sizes.
|
||||
"""
|
||||
plain = strip_ansi(text or "")
|
||||
# inner width = console width minus left/right panel paddings & border
|
||||
inner_width = max(
|
||||
10, console_width - 2 * self.PANEL_PADDING_SIDE - self.PANEL_BORDER
|
||||
)
|
||||
lines = 0
|
||||
# preserve existing newlines: wrap each paragraph separately
|
||||
for para in plain.splitlines() or [""]:
|
||||
if para.strip() == "":
|
||||
lines += 1
|
||||
continue
|
||||
wrapped = textwrap.wrap(
|
||||
para,
|
||||
width=inner_width,
|
||||
replace_whitespace=False,
|
||||
drop_whitespace=False,
|
||||
)
|
||||
lines += max(1, len(wrapped))
|
||||
text_block_lines = (
|
||||
lines + self.PANEL_PADDING_TOP + self.PANEL_PADDING_BOTTOM
|
||||
)
|
||||
extra = 2 # for panel title / subtitle / small margin
|
||||
header_h = self.HEADER_SIZE if show_header else 0
|
||||
total = header_h + text_block_lines + self.PROGRESS_SIZE + extra
|
||||
# clamp
|
||||
total = max(self.MIN_TOTAL_HEIGHT, min(total, self.MAX_TOTAL_HEIGHT))
|
||||
return int(total)
|
||||
|
||||
# try to detect terminal width; fallback to 100
|
||||
try:
|
||||
term_width = shutil.get_terminal_size().columns
|
||||
if not isinstance(term_width, int) or term_width <= 0:
|
||||
term_width = self.DEFAULT_TERM_WIDTH
|
||||
except Exception:
|
||||
term_width = self.DEFAULT_TERM_WIDTH
|
||||
|
||||
est_height = estimate_height_from_text(final_text, console_width=term_width)
|
||||
# ------------------ end new ----------------------------------------------
|
||||
|
||||
# choose rich or tqdm
|
||||
use_rich = bool(rich and _RICH_IMPORTED)
|
||||
|
||||
if not use_rich or not _RICH_IMPORTED:
|
||||
# ---------- tqdm fallback ----------
|
||||
if not _TQDM_IMPORTED:
|
||||
for i, toks in enumerate(history, start=1):
|
||||
if sleep_s > 0:
|
||||
time.sleep(sleep_s)
|
||||
print("\n✨ Generation complete!\n")
|
||||
print(final_text)
|
||||
return
|
||||
|
||||
pbar = tqdm(total=total_steps, desc="Diffusion", leave=True)
|
||||
for i, toks in enumerate(history, start=1):
|
||||
pbar.update(1)
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"masks": self._count_masks(toks),
|
||||
"pct": f"{int(100 * i / max(total_steps, 1))}%",
|
||||
}
|
||||
)
|
||||
if sleep_s > 0:
|
||||
time.sleep(sleep_s)
|
||||
pbar.close()
|
||||
print("\n✨ Generation complete!\n")
|
||||
if final_text:
|
||||
print(final_text)
|
||||
return
|
||||
|
||||
# ---------- rich live UI ----------
|
||||
# replaced fixed height=100 with the estimated height from history[-1]
|
||||
console = Console(
|
||||
force_terminal=True,
|
||||
color_system="truecolor",
|
||||
width=term_width,
|
||||
height=est_height,
|
||||
)
|
||||
layout = Layout()
|
||||
layout.split_column(
|
||||
(
|
||||
Layout(name="header", size=3)
|
||||
if show_header
|
||||
else Layout(name="header", size=0)
|
||||
),
|
||||
Layout(name="text", ratio=1),
|
||||
Layout(name="progress", size=3),
|
||||
)
|
||||
|
||||
progress = Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]Diffusion"),
|
||||
BarColumn(),
|
||||
MofNCompleteColumn(),
|
||||
TextColumn("•"),
|
||||
TextColumn("[cyan]Masks: {task.fields[masks]}"),
|
||||
TextColumn("•"),
|
||||
TextColumn("[magenta]{task.fields[pct]:>4s}"),
|
||||
TimeRemainingColumn(),
|
||||
expand=True,
|
||||
)
|
||||
|
||||
init_masks = self._count_masks(history[0]) if history else 0
|
||||
task_id = progress.add_task(
|
||||
"Generating", total=total_steps, masks=init_masks, pct="0%"
|
||||
)
|
||||
|
||||
with Live(layout, console=console, refresh_per_second=max(1, fps)):
|
||||
for step_idx, toks in enumerate(history, start=1):
|
||||
if show_header:
|
||||
header = Text(title, style="bold magenta", justify="center")
|
||||
layout["header"].update(Panel(header, border_style="bright_blue"))
|
||||
|
||||
# progress bar
|
||||
masks_remaining = self._count_masks(toks)
|
||||
pct = f"{int(100 * step_idx / max(total_steps, 1))}%"
|
||||
progress.update(task_id, advance=1, masks=masks_remaining, pct=pct)
|
||||
|
||||
# text panel: decode whole sequence (avoids Ġ/Ċ artifacts)
|
||||
if (
|
||||
every_n_steps <= 1
|
||||
or (step_idx % every_n_steps == 0)
|
||||
or step_idx in (1, total_steps)
|
||||
):
|
||||
text_str = self._detok(
|
||||
toks, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
text_str = self._truncate(text_str, max_chars)
|
||||
text_rich = Text.from_ansi(text_str) if text_str else Text("")
|
||||
layout["text"].update(
|
||||
Panel(
|
||||
(
|
||||
text_rich
|
||||
if text_rich.plain
|
||||
else Text("[dim]— no tokens —[/dim]")
|
||||
),
|
||||
title="[bold]Generated Text",
|
||||
subtitle=f"[dim]Step {step_idx}/{total_steps}[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
|
||||
layout["progress"].update(Panel(progress))
|
||||
if sleep_s > 0:
|
||||
time.sleep(sleep_s)
|
||||
|
||||
console.print("\n[bold green]✨ Generation complete![/bold green]\n")
|
||||
# console.print(
|
||||
# Panel(
|
||||
# final_text if final_text else "[dim]— no decodable text —[/dim]",
|
||||
# title="[bold]Final Generated Text",
|
||||
# border_style="green",
|
||||
# padding=(1, 2),
|
||||
# )
|
||||
# )
|
||||
|
||||
# ======================== helpers (kept inside class) ========================
|
||||
|
||||
def _has_tty(self) -> bool:
|
||||
return sys.stdout.isatty() and os.environ.get("TERM", "") not in ("", "dumb")
|
||||
|
||||
def _first_item(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x[0] if x.dim() > 1 else x
|
||||
|
||||
def _count_masks(self, toks: torch.Tensor) -> int:
|
||||
if getattr(self, "_mask_token_id", None) is None:
|
||||
return 0
|
||||
t = self._first_item(toks)
|
||||
return int((t == self._mask_token_id).sum().item())
|
||||
|
||||
def _detok(self, ids_or_tensor, *, skip_special_tokens: bool) -> str:
|
||||
"""
|
||||
Robust detokenize for list[int] / torch.Tensor([T]) / torch.Tensor([B,T]).
|
||||
Decode the whole sequence to avoid byte-level artifacts like Ġ/Ċ.
|
||||
"""
|
||||
tokenizer = self.tokenizer
|
||||
# normalize to python list[int]
|
||||
if isinstance(ids_or_tensor, torch.Tensor):
|
||||
t = self._first_item(ids_or_tensor).long()
|
||||
ids = t.tolist()
|
||||
elif isinstance(ids_or_tensor, (list, tuple)):
|
||||
ids = list(ids_or_tensor)
|
||||
else:
|
||||
# unknown type
|
||||
return ""
|
||||
|
||||
# Optionally drop specials/pad/eos *before* decode if desired
|
||||
if skip_special_tokens:
|
||||
keep = []
|
||||
specials = getattr(self, "_specials", set())
|
||||
pad_id = getattr(self, "_pad_token_id", None)
|
||||
eos_id = getattr(self, "_eos_token_id", None)
|
||||
for tid in ids:
|
||||
if tid in specials:
|
||||
continue
|
||||
if pad_id is not None and tid == pad_id:
|
||||
continue
|
||||
if eos_id is not None and tid == eos_id:
|
||||
continue
|
||||
keep.append(tid)
|
||||
ids = keep
|
||||
|
||||
# Prefer tokenizer.decode (handles Ġ/Ċ, merges properly)
|
||||
text = ""
|
||||
try:
|
||||
if hasattr(tokenizer, "decode"):
|
||||
text = tokenizer.decode(
|
||||
ids,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
else:
|
||||
# fallback: tokens -> string
|
||||
toks = tokenizer.convert_ids_to_tokens(ids)
|
||||
if hasattr(tokenizer, "convert_tokens_to_string"):
|
||||
text = tokenizer.convert_tokens_to_string(toks)
|
||||
else:
|
||||
text = " ".join(map(str, toks))
|
||||
except Exception:
|
||||
# extremely defensive fallback
|
||||
try:
|
||||
text = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
except Exception:
|
||||
text = ""
|
||||
|
||||
# sanitize control chars for terminal
|
||||
if text:
|
||||
text = text.replace("\r", "")
|
||||
return text
|
||||
|
||||
def _truncate(self, s: str, max_chars: Optional[int]) -> str:
|
||||
if max_chars is None or (isinstance(max_chars, int) and max_chars < 0):
|
||||
return s
|
||||
return s[:max_chars]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
2
dllm/dllm/core/schedulers/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .alpha import *
|
||||
from .kappa import *
|
||||
132
dllm/dllm/core/schedulers/alpha.py
Normal file
@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import math
|
||||
from typing import ClassVar, Dict, Type, Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
Number = Union[float, torch.Tensor]
|
||||
|
||||
|
||||
# ---------------- Registry-enabled Base ---------------- #
|
||||
@dataclasses.dataclass
|
||||
class BaseAlphaScheduler:
|
||||
__registry__: ClassVar[dict[str, type[BaseAlphaScheduler]]] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
BaseAlphaScheduler.__registry__[cls.__name__] = cls
|
||||
BaseAlphaScheduler.__registry__[cls.__name__.lower()] = cls
|
||||
|
||||
# Make instances callable (sched(i) -> alpha(i))
|
||||
def __call__(self, t: Number) -> Number:
|
||||
return self.alpha(t)
|
||||
|
||||
# ---- common API ----
|
||||
def alpha(self, i: Number) -> Number:
|
||||
i_t = torch.as_tensor(
|
||||
i,
|
||||
dtype=torch.float32,
|
||||
device=i.device if isinstance(i, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
|
||||
raise ValueError(f"i={i} not in [0,1]")
|
||||
out = self._alpha(i_t)
|
||||
return out.item() if isinstance(i, float) else out
|
||||
|
||||
def alpha_derivative(self, i: Number) -> Number:
|
||||
i_t = torch.as_tensor(
|
||||
i,
|
||||
dtype=torch.float32,
|
||||
device=i.device if isinstance(i, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
|
||||
raise ValueError(f"i={i} not in [0,1]")
|
||||
out = self._alpha_derivative(i_t)
|
||||
return out.item() if isinstance(i, float) else out
|
||||
|
||||
def reverse_mask_prob(self, s: Number, t: Number) -> Number:
|
||||
t_t = torch.as_tensor(
|
||||
t,
|
||||
dtype=torch.float32,
|
||||
device=t.device if isinstance(t, torch.Tensor) else None,
|
||||
)
|
||||
s_t = torch.as_tensor(
|
||||
s,
|
||||
dtype=torch.float32,
|
||||
device=s.device if isinstance(s, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= s_t) & (s_t < 1.0) & (0.0 < t_t) & (t_t <= 1.0)):
|
||||
raise ValueError(f"(t={t}, s={s}) out of range")
|
||||
if not torch.all(s_t < t_t):
|
||||
raise ValueError(f"Require s < t elementwise, but got (t={t}, s={s})")
|
||||
out = (1 - self(s_t)) / (1 - self(t_t))
|
||||
return out.item() if isinstance(t, float) and isinstance(s, float) else out
|
||||
|
||||
def weight(self, i: Number) -> Number:
|
||||
# w(t) = - α'(t) / (1 - α(t))
|
||||
return - self.alpha_derivative(i) / (1 - self.alpha(i) + 1e-6)
|
||||
|
||||
# ---- hooks implemented by subclasses ----
|
||||
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------- Implementations ---------------- #
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LinearAlphaScheduler(BaseAlphaScheduler):
|
||||
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return 1 - i
|
||||
|
||||
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return -torch.ones_like(i)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CosineAlphaScheduler(BaseAlphaScheduler):
|
||||
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return 1 - torch.cos((math.pi / 2) * (1 - i))
|
||||
|
||||
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return -(math.pi / 2) * torch.sin((math.pi / 2) * (1 - i))
|
||||
|
||||
|
||||
# ---------------- Factory helpers ---------------- #
|
||||
|
||||
|
||||
def get_alpha_scheduler_class(name: str) -> type[BaseAlphaScheduler]:
|
||||
"""Return the scheduler class by name (case-insensitive)."""
|
||||
cls = BaseAlphaScheduler.__registry__.get(
|
||||
name
|
||||
) or BaseAlphaScheduler.__registry__.get(name.lower())
|
||||
if cls is None:
|
||||
available = sorted(k for k in BaseAlphaScheduler.__registry__ if k[0].isupper())
|
||||
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
|
||||
return cls
|
||||
|
||||
|
||||
def make_alpha_scheduler(name: str, **kwargs: Any) -> BaseAlphaScheduler:
|
||||
"""Instantiate a scheduler by name with optional kwargs."""
|
||||
cls = get_alpha_scheduler_class(name)
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
# ---------------- Example usage ---------------- #
|
||||
|
||||
if __name__ == "__main__":
|
||||
lin_sched = make_alpha_scheduler("LinearalphaScheduler")
|
||||
print("Linear α(0.5):", lin_sched.alpha(0.5))
|
||||
print("Linear w(0.5):", lin_sched.weight(0.5))
|
||||
print("Linear α([.25,.5,.75]):", lin_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("==========================================")
|
||||
cos_sched = make_alpha_scheduler("CosinealphaScheduler")
|
||||
print("Cosine α(0.5):", cos_sched.alpha(0.5))
|
||||
print("Cosine w(0.5):", cos_sched.weight(0.5))
|
||||
print("Cosine α([.25,.5,.75]):", cos_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
128
dllm/dllm/core/schedulers/kappa.py
Normal file
@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import math
|
||||
from typing import ClassVar, Dict, Type, Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
Number = Union[float, torch.Tensor]
|
||||
|
||||
|
||||
# ---------------- Registry-enabled Base ---------------- #
|
||||
@dataclasses.dataclass
|
||||
class BaseKappaScheduler:
|
||||
__registry__: ClassVar[dict[str, type[BaseKappaScheduler]]] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
BaseKappaScheduler.__registry__[cls.__name__] = cls
|
||||
BaseKappaScheduler.__registry__[cls.__name__.lower()] = cls
|
||||
|
||||
# Make instances callable (sched(t) -> kappa(t))
|
||||
def __call__(self, t: Number) -> Number:
|
||||
return self.kappa(t)
|
||||
|
||||
# ---- common API ----
|
||||
def kappa(self, t: Number) -> Number:
|
||||
t_tensor = torch.as_tensor(
|
||||
t,
|
||||
dtype=torch.float32,
|
||||
device=t.device if isinstance(t, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
|
||||
raise ValueError(f"t={t} not in [0,1]")
|
||||
out = self._kappa(t_tensor)
|
||||
return out.item() if isinstance(t, float) else out
|
||||
|
||||
def kappa_derivative(self, t: Number) -> Number:
|
||||
t_tensor = torch.as_tensor(
|
||||
t,
|
||||
dtype=torch.float32,
|
||||
device=t.device if isinstance(t, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
|
||||
raise ValueError(f"t={t} not in [0,1]")
|
||||
out = self._kappa_derivative(t_tensor)
|
||||
return out.item() if isinstance(t, float) else out
|
||||
|
||||
def weight(self, t: Number) -> Number:
|
||||
# w(t) = κ'(t) / (1 - κ(t))
|
||||
return self.kappa_derivative(t) / (1 - self.kappa(t) + 1e-6)
|
||||
|
||||
# ---- hooks implemented by subclasses ----
|
||||
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------- Implementations ---------------- #
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CubicKappaScheduler(BaseKappaScheduler):
|
||||
a: float = 1.0
|
||||
b: float = 1.0
|
||||
|
||||
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ(t) = (a+1) t^3 - (a+b+1) t^2 + (b+1) t
|
||||
return (self.a + 1) * (t**3) - (self.a + self.b + 1) * (t**2) + (self.b + 1) * t
|
||||
|
||||
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ'(t) = 3(a+1) t^2 - 2(a+b+1) t + (b+1)
|
||||
return 3 * (self.a + 1) * (t**2) - 2 * (self.a + self.b + 1) * t + (self.b + 1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LinearKappaScheduler(CubicKappaScheduler):
|
||||
# Special case: κ(t) = t corresponds to a=-1, b=0
|
||||
a: float = -1.0
|
||||
b: float = 0.0
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CosineKappaScheduler(BaseKappaScheduler):
|
||||
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ(t) = 1 - cos((π/2) * t)
|
||||
return 1.0 - torch.cos(0.5 * math.pi * t)
|
||||
|
||||
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ'(t) = (π/2) * sin((π/2) * t)
|
||||
return 0.5 * math.pi * torch.sin(0.5 * math.pi * t)
|
||||
|
||||
|
||||
# ---------------- Factory helpers ---------------- #
|
||||
|
||||
|
||||
def get_kappa_scheduler_class(name: str) -> type[BaseKappaScheduler]:
|
||||
"""Return the scheduler class by name (case-insensitive)."""
|
||||
cls = BaseKappaScheduler.__registry__.get(
|
||||
name
|
||||
) or BaseKappaScheduler.__registry__.get(name.lower())
|
||||
if cls is None:
|
||||
available = sorted(k for k in BaseKappaScheduler.__registry__ if k[0].isupper())
|
||||
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
|
||||
return cls
|
||||
|
||||
|
||||
def make_kappa_scheduler(name: str, **kwargs: Any) -> BaseKappaScheduler:
|
||||
"""Instantiate a scheduler by name with optional kwargs."""
|
||||
cls = get_kappa_scheduler_class(name)
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
# ---------------- Example usage ---------------- #
|
||||
|
||||
if __name__ == "__main__":
|
||||
lin_sched = make_kappa_scheduler("LinearKappaScheduler")
|
||||
print("Linear κ(0.5):", lin_sched.kappa(0.5))
|
||||
print("Linear w(0.5):", lin_sched.weight(0.5))
|
||||
print("Linear κ([.25,.5,.75]):", lin_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("==========================================")
|
||||
cos_sched = make_kappa_scheduler("CosineKappaScheduler")
|
||||
print("Cosine κ(0.5):", cos_sched.kappa(0.5))
|
||||
print("Cosine w(0.5):", cos_sched.weight(0.5))
|
||||
print("Cosine κ([.25,.5,.75]):", cos_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
1
dllm/dllm/core/trainers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from dllm.core.trainers.mdlm import MDLMTrainer
|
||||
140
dllm/dllm/core/trainers/mdlm.py
Normal file
@ -0,0 +1,140 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from typing import Any
|
||||
|
||||
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
|
||||
|
||||
|
||||
class MDLMTrainer(transformers.Trainer):
|
||||
"""
|
||||
Masked Diffusion Language Model Trainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
scheduler: BaseAlphaScheduler | None = None,
|
||||
time_epsilon: float = 1e-3,
|
||||
loss_weight_type: str = "scheduler", # "ones"
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.scheduler = scheduler or LinearAlphaScheduler()
|
||||
if not (0.0 < time_epsilon < 1.0):
|
||||
raise ValueError("time_epsilon must be in (0, 1)")
|
||||
self.time_epsilon = time_epsilon
|
||||
self.loss_weight_type = loss_weight_type
|
||||
|
||||
def _preprocess_inputs(self, inputs):
|
||||
pass
|
||||
|
||||
def _postprocess_outputs(self, outputs):
|
||||
pass
|
||||
|
||||
def _compute_loss_weights(
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
inputs: dict[str, Any],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Compute loss weights given timestep t and other arguments."""
|
||||
b, l = inputs["input_ids"].shape
|
||||
if self.loss_weight_type == "scheduler":
|
||||
loss_weights = self.scheduler.weight(t).unsqueeze(1).repeat(1, l)  # (b, l)
|
||||
elif self.loss_weight_type == "ones":
|
||||
loss_weights = torch.ones_like(inputs["input_ids"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss_weights
|
||||
|
||||
@torch.no_grad()
|
||||
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||
if prediction_loss_only:
|
||||
return (loss.detach(), None, None)
|
||||
|
||||
logits = getattr(outputs, "logits", outputs)
|
||||
if isinstance(logits, torch.Tensor):
|
||||
logits = logits.detach().contiguous()
|
||||
|
||||
labels = inputs.get("labels")
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = labels.detach().contiguous()
|
||||
|
||||
return (loss.detach(), logits, labels)
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: transformers.PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
return_outputs: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
assert self.processing_class.padding_side == "right"
|
||||
self._preprocess_inputs(inputs)
|
||||
input_ids, labels, attention_mask = (
|
||||
inputs["input_ids"],
|
||||
inputs["labels"],
|
||||
inputs.get("attention_mask", None),
|
||||
)
|
||||
b, l = input_ids.shape
|
||||
|
||||
# === 1. Sample diffusion timesteps ===
|
||||
# Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0.
|
||||
# The scheduler defines the survival probability α(t) (the chance a token stays unmasked); the masking probability is p_mask = 1 - α(t).
|
||||
t = self.time_epsilon + (1 - self.time_epsilon) * torch.rand(
|
||||
b, device=input_ids.device
|
||||
)
|
||||
p_mask = 1 - self.scheduler(t).unsqueeze(1).expand(b, l)
|
||||
|
||||
# === 2. Apply stochastic masking ===
|
||||
# Tokens are masked independently according to p_mask(t).
|
||||
# Positions with label = -100 are excluded (ignored in loss).
|
||||
masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & (
|
||||
labels != -100
|
||||
)
|
||||
# Replace masked tokens with the special [MASK] token.
|
||||
noised_input_ids = torch.where(
|
||||
masked_indices, self.processing_class.mask_token_id, input_ids
|
||||
)
|
||||
|
||||
# === 3. Forward pass through the model ===
|
||||
# The model predicts clean tokens given noised inputs.
|
||||
outputs = model(input_ids=noised_input_ids, attention_mask=attention_mask)
|
||||
self._postprocess_outputs(outputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# === 4. Handle degenerate cases (no tokens masked) ===
|
||||
# If no positions were masked, return a zero loss to keep gradients valid.
|
||||
# This step is necessary for Deepspeed Zero-{2,3}
|
||||
if not masked_indices.any():
|
||||
return (
|
||||
(logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0
|
||||
)
|
||||
|
||||
# === 5. Compute per-token loss weights ===
|
||||
# Depending on the configuration, weights may depend on timestep t
|
||||
# (e.g., scheduler-based) or be uniform (ones).
|
||||
loss_weights = self._compute_loss_weights(
|
||||
t=t, inputs=inputs, masked_indices=masked_indices
|
||||
)
|
||||
|
||||
# === 6. Compute weighted cross-entropy ===
|
||||
# Only masked tokens contribute to the loss.
|
||||
assert (input_ids[masked_indices] == labels[masked_indices]).all()
|
||||
token_loss = F.cross_entropy(
|
||||
logits[masked_indices], input_ids[masked_indices], reduction="none"
|
||||
)
|
||||
token_loss = token_loss * loss_weights[masked_indices]
|
||||
|
||||
# === 7. Normalize loss per effective token length ===
|
||||
# Normalize each sequence’s contribution by its number of valid tokens,
|
||||
# then average over the batch for stability across variable-length inputs.
|
||||
effective_lengths = torch.sum(labels != -100, dim=1, keepdim=True).expand(b, l)
|
||||
loss = torch.sum(token_loss / effective_lengths[masked_indices]) / b
|
||||
|
||||
# === 8. Return final loss (and optionally model outputs) ===
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
1
dllm/dllm/data/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .utils import load_sft_dataset, load_pt_dataset
|
||||
63
dllm/dllm/data/alpaca.py
Normal file
@ -0,0 +1,63 @@
|
||||
from typing import Optional
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
|
||||
def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str:
|
||||
"""Construct a clean text prompt from Alpaca fields.
|
||||
|
||||
We intentionally *do not* include Anthropic-style role tags (e.g., "Human:", "Assistant:")
|
||||
in the returned prompt, to mirror the return shape of `load_hh_rlhf_dataset` which removes
|
||||
those tags from the prompt it returns.
|
||||
"""
|
||||
instruction = (instruction or "").strip()
|
||||
input_text = (input_text or "").strip()
|
||||
|
||||
if input_text:
|
||||
# Keep instruction and input separated by a blank line for readability.
|
||||
return f"{instruction}\n\n{input_text}"
|
||||
else:
|
||||
return instruction
|
||||
|
||||
|
||||
def load_dataset_alpaca(dataset_name_or_path: str) -> DatasetDict:
|
||||
"""Load the Alpaca dataset (tatsu-lab/alpaca) and expose unified fields.
|
||||
|
||||
Returns a `DatasetDict` where each split contains:
|
||||
- prompt: Combined instruction (+ optional input), with clean formatting
|
||||
- response: The target output (model answer)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset_name_or_path : str
|
||||
Usually "tatsu-lab/alpaca" or a local path.
|
||||
"""
|
||||
dataset = load_dataset(dataset_name_or_path)
|
||||
|
||||
def map_fn(example):
|
||||
prompt = _build_alpaca_prompt(
|
||||
example.get("instruction", ""), example.get("input", "")
|
||||
)
|
||||
response = (example.get("output", "") or "").strip()
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": response},
|
||||
]
|
||||
}
|
||||
|
||||
dataset = dataset.map(
|
||||
map_fn, remove_columns=dataset["train"].column_names, num_proc=4
|
||||
)
|
||||
# make train test split
|
||||
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dllm.utils import resolve_with_base_env
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
"tatsu-lab/alpaca", "BASE_DATASETS_DIR"
|
||||
)
|
||||
dataset = load_dataset_alpaca(dataset_name_or_path)
|
||||
breakpoint()
|
||||
133
dllm/dllm/data/opc.py
Normal file
@ -0,0 +1,133 @@
|
||||
from typing import Optional, Text, List, Dict
|
||||
from datasets import (
|
||||
load_dataset,
|
||||
get_dataset_config_names,
|
||||
concatenate_datasets,
|
||||
DatasetDict,
|
||||
Dataset,
|
||||
IterableDatasetDict,
|
||||
)
|
||||
from dllm.data.utils import (
|
||||
_merge_datasetdicts,
|
||||
_merge_iterabledatasetdicts,
|
||||
_ensure_datasetdict,
|
||||
_ensure_iterabledatasetdict,
|
||||
)
|
||||
|
||||
|
||||
def load_dataset_opc_sft(
|
||||
dataset_name_or_path: str, name: str | None = None, lang: str | None = None
|
||||
) -> DatasetDict:
|
||||
"""
|
||||
Load OpenCoder OPC SFT dataset(s) and produce a DatasetDict with a train/test split.
|
||||
- If `name` is provided: load that specific config.
|
||||
- If `name` is None: load *all* available configs and concatenate them.
|
||||
"""
|
||||
|
||||
def _map_to_messages(ds: Dataset) -> Dataset:
|
||||
def map_fn(example):
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": example["instruction"]},
|
||||
{"role": "assistant", "content": example["output"]},
|
||||
]
|
||||
}
|
||||
|
||||
# Remove all original columns after mapping
|
||||
remove_cols = ds.column_names
|
||||
return ds.map(map_fn, remove_columns=remove_cols, num_proc=4)
|
||||
|
||||
def _load_one_config(dataset_name_or_path: str, cfg_name: str) -> Dataset:
|
||||
ds = load_dataset(dataset_name_or_path, cfg_name, split="train")
|
||||
return _map_to_messages(ds)
|
||||
|
||||
if name is not None:
|
||||
train_ds = _load_one_config(dataset_name_or_path, name)
|
||||
else:
|
||||
# Enumerate and load all configs, then concatenate
|
||||
cfgs: list[str] = get_dataset_config_names(dataset_name_or_path)
|
||||
if not cfgs:
|
||||
raise ValueError(f"No configs found for dataset: {dataset_name_or_path}")
|
||||
parts = [_load_one_config(dataset_name_or_path, c) for c in cfgs]
|
||||
train_ds = concatenate_datasets(parts)
|
||||
|
||||
# Final split
|
||||
ds_dict = train_ds.train_test_split(test_size=0.1, seed=42)
|
||||
if lang is not None:
|
||||
ds_dict = ds_dict.filter(lambda row: lang in row["messages"][1]["content"])
|
||||
|
||||
return DatasetDict(ds_dict)
|
||||
|
||||
|
||||
def load_dataset_opc_annealing(
|
||||
dataset_name_or_path: str,
|
||||
name: str | None = None,
|
||||
lang: str | None = None,
|
||||
streaming: bool = True,
|
||||
) -> DatasetDict:
|
||||
def _load_one_config(_name):
|
||||
ds = load_dataset(
|
||||
dataset_name_or_path, _name, split="train", streaming=streaming
|
||||
)
|
||||
if lang:
|
||||
if _name in ["synthetic_code_snippet", "algorithmic_corpus"]:
|
||||
ds = ds.filter(lambda row: row["lang"] == lang)
|
||||
elif _name in ["synthetic_qa"]:
|
||||
ds = ds.filter(lambda row: row["program_lang"] == lang)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# return IterableDatasetDict({"train": ds})
|
||||
if streaming:
|
||||
return _ensure_iterabledatasetdict(ds)
|
||||
return _ensure_datasetdict(ds)
|
||||
|
||||
if name is not None:
|
||||
return _load_one_config(name)
|
||||
|
||||
if streaming:
|
||||
parts = [
|
||||
_load_one_config(name)
|
||||
for name in get_dataset_config_names(dataset_name_or_path)
|
||||
]
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_iterabledatasetdicts(merged, p)
|
||||
return merged
|
||||
else:
|
||||
parts = [
|
||||
_load_one_config(name)
|
||||
for name in get_dataset_config_names(dataset_name_or_path)
|
||||
]
|
||||
if len(parts) == 1:
|
||||
return _ensure_datasetdict(parts[0])
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_datasetdicts(merged, p)
|
||||
return _ensure_datasetdict(merged)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dllm.utils import resolve_with_base_env
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
"OpenCoder-LLM/opc-sft-stage1", "BASE_DATASETS_DIR"
|
||||
)
|
||||
# If you want a specific config:
|
||||
dataset_edu = load_dataset_opc_sft(dataset_name_or_path, "realuser_instruct")
|
||||
# Otherwise, all configs concatenated:
|
||||
dataset_all = load_dataset_opc_sft(dataset_name_or_path, None)
|
||||
dataset_all_python = load_dataset_opc_sft(dataset_name_or_path, None, "python")
|
||||
breakpoint()
|
||||
|
||||
# streaming = True
|
||||
# dataset_name_or_path = resolve_with_base_env(
|
||||
# "OpenCoder-LLM/opc-annealing-corpus", "BASE_DATASETS_DIR"
|
||||
# )
|
||||
# # If you want a specific config:
|
||||
# dataset_alg_all = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus")
|
||||
# dataset_alg_python = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus", "python")
|
||||
# # Otherwise, all configs concatenated:
|
||||
# dataset_all_python = load_dataset_opc_annealing(dataset_name_or_path, None, "python")
|
||||
# dataset_all_all = load_dataset_opc_annealing(dataset_name_or_path, None)
|
||||
# breakpoint()
|
||||
108
dllm/dllm/data/ultrachat.py
Normal file
@ -0,0 +1,108 @@
|
||||
from typing import Optional, List, Dict
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
|
||||
def _extract_first_turn(messages: list[dict[str, str]]) -> dict[str, str] | None:
|
||||
"""
|
||||
Given a list of chat messages like:
|
||||
[{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
...]
|
||||
return a dict with the first user/assistant exchange as:
|
||||
{"prompt": <user content>, "response": <assistant content>}
|
||||
If no valid first turn exists, return None.
|
||||
"""
|
||||
if not isinstance(messages, list) or len(messages) < 2:
|
||||
return None
|
||||
|
||||
# Find the first user message and the first assistant *after* that user msg
|
||||
# (Most entries start as [user, assistant, ...], but we guard anyway.)
|
||||
user_idx = None
|
||||
for i, m in enumerate(messages):
|
||||
if (
|
||||
isinstance(m, dict)
|
||||
and m.get("role") == "user"
|
||||
and isinstance(m.get("content"), str)
|
||||
):
|
||||
user_idx = i
|
||||
break
|
||||
if user_idx is None:
|
||||
return None
|
||||
|
||||
# Find first assistant after that user
|
||||
for j in range(user_idx + 1, len(messages)):
|
||||
m = messages[j]
|
||||
if (
|
||||
isinstance(m, dict)
|
||||
and m.get("role") == "assistant"
|
||||
and isinstance(m.get("content"), str)
|
||||
):
|
||||
user_text = messages[user_idx]["content"].strip()
|
||||
assistant_text = m["content"].strip()
|
||||
if user_text and assistant_text:
|
||||
return {"prompt": user_text, "response": assistant_text}
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
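For example (made-up messages, illustrative only):

example_messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "And 3 + 3?"},
]
assert _extract_first_turn(example_messages) == {"prompt": "What is 2 + 2?", "response": "4"}
assert _extract_first_turn([]) is None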
def load_dataset_ultrachat(dataset_name_or_path: str) -> DatasetDict:
|
||||
"""
|
||||
Load the UltraChat 200k dataset (HuggingFaceH4/ultrachat_200k) and keep only the *first turn*
|
||||
(first user message and the assistant reply).
|
||||
|
||||
Returns a `DatasetDict` where each split contains:
|
||||
- prompt: first user message content
|
||||
- response: first assistant reply content
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset_name_or_path : str
|
||||
Typically "HuggingFaceH4/ultrachat_200k" or a local path.
|
||||
"""
|
||||
dataset = load_dataset(dataset_name_or_path)
|
||||
|
||||
# We only keep examples that have a valid first (user, assistant) turn.
|
||||
def has_first_turn(example):
|
||||
messages = example.get("messages")
|
||||
return _extract_first_turn(messages) is not None
|
||||
|
||||
dataset = dataset.filter(has_first_turn, num_proc=4)
|
||||
|
||||
def map_fn(example):
|
||||
first = _extract_first_turn(example["messages"])
|
||||
# Fallbacks for robustness (shouldn't be hit after filter, but just in case)
|
||||
if first is None:
|
||||
first = {"prompt": (example.get("prompt") or "").strip(), "response": ""}
|
||||
return {"prompt": first["prompt"], "response": first["response"]}
|
||||
|
||||
# Remove original columns for a clean schema (infer from any available split)
|
||||
cols_to_remove = None
|
||||
for split_name in dataset.keys():
|
||||
cols_to_remove = dataset[split_name].column_names
|
||||
break
|
||||
|
||||
dataset = dataset.map(map_fn, remove_columns=cols_to_remove, num_proc=4)
|
||||
dataset = DatasetDict(
|
||||
{
|
||||
new: dataset[old]
|
||||
for old, new in {
|
||||
"train_sft": "train",
|
||||
"test_sft": "test",
|
||||
}.items()
|
||||
if old in dataset
|
||||
}
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Mirrors the other loaders: resolve the path via the env helper if available.
|
||||
from dllm.utils import resolve_with_base_env
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
"HuggingFaceH4/ultrachat_200k", "BASE_DATASETS_DIR"
|
||||
)
|
||||
dataset = load_dataset_ultrachat(dataset_name_or_path)
|
||||
breakpoint()
|
||||
377
dllm/dllm/data/utils.py
Normal file
@ -0,0 +1,377 @@
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDatasetDict,
|
||||
IterableDataset,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
|
||||
from dllm.utils.utils import resolve_with_base_env, parse_spec, get_default_logger
|
||||
|
||||
|
||||
logger = get_default_logger(__name__)
|
||||
|
||||
|
||||
def load_sft_dataset(
|
||||
dataset_args: str, load_preprocessed_data: bool = False
|
||||
) -> DatasetDict:
|
||||
"""
|
||||
Examples of dataset_args:
|
||||
- "tatsu-lab/alpaca"
|
||||
- "OpenCoder-LLM/opc-sft-stage2[name:educational_instruct]"
|
||||
- "tatsu-lab/alpaca[train:5000]"
|
||||
- "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]"
|
||||
"""
|
||||
from dllm.data.alpaca import load_dataset_alpaca
|
||||
from dllm.data.opc import load_dataset_opc_sft
|
||||
|
||||
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
|
||||
all_parts = []
|
||||
|
||||
for raw in specs:
|
||||
dataset_name_or_path, kvs = parse_spec(raw)
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
dataset_name_or_path, "BASE_DATASETS_DIR"
|
||||
)
|
||||
|
||||
if load_preprocessed_data:
|
||||
logger.info("Load preprocessed data from disk.")
|
||||
ds = load_from_disk(dataset_name_or_path)
|
||||
# Implement your customized dataset here
|
||||
elif _match(dataset_name_or_path, "tatsu-lab/alpaca"):
|
||||
ds = load_dataset_alpaca(dataset_name_or_path)
|
||||
elif _match(dataset_name_or_path, "allenai/tulu-3-sft-mixture"):
|
||||
ds = load_dataset(dataset_name_or_path)
|
||||
ds = ds["train"].train_test_split(test_size=0.1, seed=42)
|
||||
elif _match(dataset_name_or_path, "HuggingFaceTB/smoltalk"):
|
||||
name = kvs.pop("name", "all")
|
||||
ds = load_dataset(dataset_name_or_path, name=name)
|
||||
elif _match(dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage1") or _match(
|
||||
dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage2"
|
||||
):
|
||||
name = kvs.pop("name", None)
|
||||
lang = kvs.pop("lang", None)
|
||||
ds = load_dataset_opc_sft(dataset_name_or_path, name=name, lang=lang)
|
||||
elif _match(dataset_name_or_path, "HuggingFaceH4/ultrachat_200k"):
|
||||
ds = load_dataset(dataset_name_or_path)
|
||||
ds = DatasetDict({"train": ds["train_sft"], "test": ds["test_sft"]})
|
||||
else:
|
||||
ds = load_dataset(dataset_name_or_path)
|
||||
|
||||
# Normalize to DatasetDict and apply per-split limits
|
||||
ds = _ensure_datasetdict(ds)
|
||||
ds = _truncate_dataset(ds, kvs)
|
||||
all_parts.append(ds)
|
||||
|
||||
# If only one part, return as DatasetDict
|
||||
if len(all_parts) == 1:
|
||||
return _ensure_datasetdict(all_parts[0])
|
||||
|
||||
# Merge all parts into a single DatasetDict
|
||||
merged = all_parts[0]
|
||||
for part in all_parts[1:]:
|
||||
merged = _merge_datasetdicts(merged, part)
|
||||
return _ensure_datasetdict(merged)
|
||||
|
||||
|
||||
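A usage sketch (hedged; it assumes the dataset is reachable on the Hub or under BASE_DATASETS_DIR):

from dllm.data import load_sft_dataset

# Cap the Alpaca train split at 5k rows; the loader itself carves out a 10% test split.
ds = load_sft_dataset("tatsu-lab/alpaca[train:5000]")
print(ds)
print(ds["train"][0]["messages"][0]["role"])  # "user"

# Several specs can be piped together and are merged split-by-split, e.g.
# "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]".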
def load_pt_dataset(
|
||||
dataset_args: str, streaming: bool = True, load_preprocessed_data: bool = False
|
||||
) -> DatasetDict | IterableDatasetDict:
|
||||
"""
|
||||
Examples of dataset_args:
|
||||
- "mlfoundations/dclm-baseline-1.0"
|
||||
- "OpenCoder-LLM/opc-fineweb-code-corpus"
|
||||
- "OpenCoder-LLM/opc-fineweb-math-corpus"
|
||||
- "OpenCoder-LLM/opc-annealing-corpus[lang:python]"
|
||||
- "wikitext[name:wikitext-103-v1}]"
|
||||
"""
|
||||
from dllm.data.opc import load_dataset_opc_annealing
|
||||
|
||||
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
|
||||
if not specs:
|
||||
raise ValueError("Empty dataset_args for load_pt_dataset.")
|
||||
|
||||
# ---------- Shared loader (only differs by streaming flag) ----------
|
||||
def _load_base_dataset(
|
||||
raw: str, *, streaming: bool
|
||||
) -> tuple[DatasetDict | IterableDatasetDict, dict, str]:
|
||||
"""
|
||||
Returns: (base, kvs, dataset_name_or_path)
|
||||
- Pops 'name' from kvs when applicable (e.g., wikitext).
|
||||
- Applies identical matching logic for both streaming/non-streaming.
|
||||
"""
|
||||
dataset_name_or_path, kvs = parse_spec(raw)
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
dataset_name_or_path, "BASE_DATASETS_DIR"
|
||||
)
|
||||
name = kvs.pop("name", None)
|
||||
|
||||
if load_preprocessed_data:
|
||||
base = load_from_disk(dataset_name_or_path)
|
||||
elif _match(dataset_name_or_path, ["OpenCoder-LLM/opc-annealing-corpus"]):
|
||||
lang = kvs.pop("lang", None)
|
||||
base = load_dataset_opc_annealing(
|
||||
dataset_name_or_path, name=name, lang=lang, streaming=streaming
|
||||
)
|
||||
else:
|
||||
base = load_dataset(dataset_name_or_path, name=name, streaming=streaming)
|
||||
|
||||
return base, kvs, dataset_name_or_path
|
||||
|
||||
# ---------- Streaming path ----------
|
||||
def _load_one_streaming_spec(raw: str) -> IterableDatasetDict:
|
||||
base, kvs, dataset_name_or_path = _load_base_dataset(raw, streaming=True)
|
||||
|
||||
split_names = list(base.keys())
|
||||
single_split = len(split_names) == 1
|
||||
single_split_name = split_names[0] if single_split else None
|
||||
|
||||
n_train = kvs.get("train")
|
||||
n_test = kvs.get("test")
|
||||
|
||||
if (n_train is not None) or (n_test is not None):
|
||||
if (n_train is not None) and (n_test is not None):
|
||||
if single_split:
|
||||
stream = base[single_split_name]
|
||||
head = stream.take(n_train + n_test)
|
||||
test = head.take(n_test)
|
||||
train = head.skip(n_test).take(n_train)
|
||||
return IterableDatasetDict({"train": train, "test": test})
|
||||
else:
|
||||
if "train" not in base or "test" not in base:
|
||||
raise ValueError(
|
||||
f"{dataset_name_or_path}: require 'train' and 'test' splits for train+test limits."
|
||||
)
|
||||
train = base["train"].take(n_train)
|
||||
test = base["test"].take(n_test)
|
||||
return IterableDatasetDict({"train": train, "test": test})
|
||||
|
||||
if n_train is not None:
|
||||
if single_split:
|
||||
train = base[single_split_name].take(n_train)
|
||||
else:
|
||||
if "train" not in base:
|
||||
raise ValueError(
|
||||
f"{dataset_name_or_path}: missing 'train' split for train limit."
|
||||
)
|
||||
train = base["train"].take(n_train)
|
||||
return IterableDatasetDict({"train": train})
|
||||
|
||||
if n_test is not None:
|
||||
if single_split:
|
||||
test = base[single_split_name].take(n_test)
|
||||
else:
|
||||
if "test" not in base:
|
||||
raise ValueError(
|
||||
f"{dataset_name_or_path}: missing 'test' split for test limit."
|
||||
)
|
||||
test = base["test"].take(n_test)
|
||||
return IterableDatasetDict({"test": test})
|
||||
|
||||
return base # already an IterableDatasetDict
|
||||
|
||||
# ---------- Non-streaming path (mirror load_sft_dataset; NO shuffle) ----------
|
||||
def _load_one_nonstreaming_spec(raw: str) -> DatasetDict:
|
||||
base, kvs, _ = _load_base_dataset(raw, streaming=False)
|
||||
ds = _ensure_datasetdict(base) # normalize
|
||||
ds = _truncate_dataset(ds, kvs) # apply limits (train/test/...)
|
||||
return ds
|
||||
|
||||
# ---------- Load & Merge ----------
|
||||
if streaming:
|
||||
logger.info("Loading dataset in streaming mode.")
|
||||
parts = [_load_one_streaming_spec(raw) for raw in specs]
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_iterabledatasetdicts(merged, p)
|
||||
# repeat streaming dataset infinitely
|
||||
merged = IterableDatasetDict(
|
||||
{k: (v.repeat(None) if k == "train" else v) for k, v in merged.items()}
|
||||
)
|
||||
return merged
|
||||
else:
|
||||
logger.info("Loading dataset in non-streaming mode.")
|
||||
parts = [_load_one_nonstreaming_spec(raw) for raw in specs]
|
||||
if len(parts) == 1:
|
||||
return _ensure_datasetdict(parts[0])
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_datasetdicts(merged, p)
|
||||
return _ensure_datasetdict(merged)
|
||||
|
||||
|
||||
def _truncate_split(split_data, n: int):
|
||||
if n is None:
|
||||
return split_data
|
||||
try:
|
||||
if hasattr(split_data, "select"):
|
||||
# Hugging Face Dataset path
|
||||
total = getattr(split_data, "num_rows", None)
|
||||
if total is None:
|
||||
# some Dataset types expose len(...)
|
||||
total = len(split_data)
|
||||
idx = list(range(min(n, total)))
|
||||
return split_data.select(idx)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return split_data[:n]
|
||||
except Exception:
|
||||
# Last resort: iterate
|
||||
return type(split_data)(item for i, item in enumerate(split_data) if i < n)
|
||||
|
||||
|
||||
def _truncate_dataset(ds, limits: dict):
|
||||
"""
|
||||
Ensure and return a DatasetDict, truncating splits mentioned in `limits`.
|
||||
"""
|
||||
ds = _ensure_datasetdict(ds) # normalize first
|
||||
out = {}
|
||||
for split, data in ds.items():
|
||||
n = limits.get(split, None)
|
||||
out[split] = _truncate_split(data, n) if n is not None else data
|
||||
return DatasetDict(out)
|
||||
|
||||
|
||||
def _concat_splits(a, b):
|
||||
"""
|
||||
Concatenate two split objects (prefer 🤗 datasets).
|
||||
"""
|
||||
if a is b:
|
||||
return a
|
||||
if a is None:
|
||||
return b
|
||||
if b is None:
|
||||
return a
|
||||
|
||||
# Prefer datasets' concatenate_datasets when both are Datasets
|
||||
try:
|
||||
from datasets import concatenate_datasets
|
||||
|
||||
if isinstance(a, Dataset) and isinstance(b, Dataset):
|
||||
return concatenate_datasets([a, b])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallbacks
|
||||
try:
|
||||
return a + b
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return type(a)(list(a) + list(b))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise TypeError(
|
||||
f"Cannot concatenate split objects of types {type(a)} and {type(b)}"
|
||||
)
|
||||
|
||||
|
||||
def _merge_datasetdicts(d1, d2):
|
||||
"""
|
||||
Merge two DatasetDict-like mappings by concatenating splits present in either.
|
||||
Always returns a DatasetDict.
|
||||
"""
|
||||
d1 = _ensure_datasetdict(d1)
|
||||
d2 = _ensure_datasetdict(d2)
|
||||
all_splits = set(d1.keys()) | set(d2.keys())
|
||||
out = {}
|
||||
for split in all_splits:
|
||||
a = d1.get(split, None)
|
||||
b = d2.get(split, None)
|
||||
if a is None:
|
||||
out[split] = b
|
||||
elif b is None:
|
||||
out[split] = a
|
||||
else:
|
||||
out[split] = _concat_splits(a, b)
|
||||
return DatasetDict(out)
|
||||
|
||||
|
||||
def _ensure_datasetdict(ds):
|
||||
"""
|
||||
Normalize various loader outputs into a DatasetDict.
|
||||
- If loader returns a DatasetDict, return as is.
|
||||
- If loader returns a mapping (e.g., dict of splits), wrap into DatasetDict.
|
||||
- If loader returns a single Dataset/list/etc., assume it's 'train'.
|
||||
"""
|
||||
if isinstance(ds, DatasetDict):
|
||||
return ds
|
||||
if isinstance(ds, dict):
|
||||
# Try to convert each split value to a Dataset if they aren't already.
|
||||
# If they are already Datasets, DatasetDict will accept them directly.
|
||||
return DatasetDict(ds)
|
||||
# Single split -> assume train
|
||||
return DatasetDict({"train": ds})
|
||||
|
||||
|
||||
def _match(name: str, needle) -> bool:
|
||||
"""
|
||||
Returns True if `name` matches any of the provided needles.
|
||||
Accepts a single string or a list/tuple of strings.
|
||||
Match condition: name endswith(needle) or needle in name.
|
||||
"""
|
||||
if isinstance(needle, (list, tuple)):
|
||||
return any(name.endswith(n) or n in name for n in needle)
|
||||
return name.endswith(needle) or needle in name
|
||||
|
||||
|
||||
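For instance (illustrative only), both a Hub id and a resolved local path hit the same branch:

assert _match("OpenCoder-LLM/opc-sft-stage1", "OpenCoder-LLM/opc-sft-stage1")
assert _match("/data/hub/OpenCoder-LLM/opc-sft-stage1", ["opc-sft-stage1", "opc-sft-stage2"])
assert not _match("tatsu-lab/alpaca", "ultrachat_200k")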
def _concat_iterable_datasets(parts: list[IterableDataset]) -> IterableDataset:
|
||||
"""
|
||||
Concatenate IterableDatasets sequentially without materialization.
|
||||
Preserves streaming nature; supports downstream .take()/.skip()/.shuffle().
|
||||
"""
|
||||
if not parts:
|
||||
raise ValueError("No IterableDatasets to concatenate.")
|
||||
# Try to reuse features from the first dataset when available
|
||||
features = getattr(parts[0], "features", None)
|
||||
|
||||
def _gen():
|
||||
for ds in parts:
|
||||
yield from ds
|
||||
|
||||
return IterableDataset.from_generator(_gen, features=features)
|
||||
|
||||
|
||||
def _ensure_iterabledatasetdict(obj) -> IterableDatasetDict:
|
||||
if isinstance(obj, IterableDatasetDict):
|
||||
return obj
|
||||
if isinstance(obj, dict):
|
||||
return IterableDatasetDict(obj)
|
||||
# Single stream -> assume train
|
||||
return IterableDatasetDict({"train": obj})
|
||||
|
||||
|
||||
def _merge_iterabledatasetdicts(
|
||||
d1: IterableDatasetDict, d2: IterableDatasetDict
|
||||
) -> IterableDatasetDict:
|
||||
"""
|
||||
Merge by concatenating any overlapping splits (streaming-safe).
|
||||
"""
|
||||
d1 = _ensure_iterabledatasetdict(d1)
|
||||
d2 = _ensure_iterabledatasetdict(d2)
|
||||
all_splits = set(d1.keys()) | set(d2.keys())
|
||||
out = {}
|
||||
for split in all_splits:
|
||||
a = d1.get(split, None)
|
||||
b = d2.get(split, None)
|
||||
if a is None:
|
||||
out[split] = b
|
||||
elif b is None:
|
||||
out[split] = a
|
||||
else:
|
||||
out[split] = _concat_iterable_datasets([a, b])
|
||||
return IterableDatasetDict(out)
|
||||
|
||||
|
||||
def _truncate_stream(ds: IterableDataset, n: int | None) -> IterableDataset:
|
||||
if n is None:
|
||||
return ds
|
||||
return ds.take(n)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
breakpoint()
|
||||
1
dllm/dllm/pipelines/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import llada, dream, rnd, editflow
|
||||
362
dllm/dllm/pipelines/bert/eval.py
Normal file
@ -0,0 +1,362 @@
|
||||
"""
|
||||
accelerate launch \
|
||||
--num_processes 2 \
|
||||
dllm/pipelines/bert/eval.py \
|
||||
--tasks gsm8k \
|
||||
--batch_size 1 \
|
||||
--model bert \
|
||||
--device cuda \
|
||||
--num_fewshot 8 \
|
||||
--model_args "pretrained=dllm-collection/ModernBERT-base-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
from lm_eval.api.instance import Instance
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.api.registry import register_model
|
||||
from lm_eval.models.utils import get_dtype
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BERTEvalConfig(LLaDAGeneratorConfig):
|
||||
max_new_tokens: int = 128
|
||||
max_length: int = 512
|
||||
steps: int = 128
|
||||
block_length: int = 128
|
||||
|
||||
pretrained: str = ""
|
||||
dtype: str | torch.dtype = "auto"
|
||||
batch_size: int = 32
|
||||
mc_num: int = 128
|
||||
is_check_greedy: bool = True
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@register_model("bert")
|
||||
class BERTEvalHarness(LM):
|
||||
def __init__(
|
||||
self,
|
||||
config: BERTEvalConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Initialize config if not provided
|
||||
if config is None:
|
||||
config = BERTEvalConfig()
|
||||
|
||||
# Pull args from config, allow kwargs to override
|
||||
pretrained = kwargs.get("pretrained", config.pretrained)
|
||||
dtype = kwargs.get("dtype", config.dtype)
|
||||
batch_size = kwargs.get("batch_size", config.batch_size)
|
||||
mc_num = kwargs.get("mc_num", config.mc_num)
|
||||
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
|
||||
device = kwargs.get("device", config.device)
|
||||
cfg = kwargs.get("cfg", config.cfg_scale)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
|
||||
accelerator = accelerate.Accelerator()
|
||||
|
||||
# Get GLOBAL rank from torch.distributed (not accelerator)
|
||||
if torch.distributed.is_initialized():
|
||||
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
|
||||
self._world_size = (
|
||||
torch.distributed.get_world_size()
|
||||
) # ← GLOBAL world size (16)
|
||||
else:
|
||||
self._rank = 0
|
||||
self._world_size = 1
|
||||
|
||||
# Use accelerator for device placement
|
||||
self.model = dllm.utils.get_model(
|
||||
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# Let accelerator handle device placement
|
||||
self.model = accelerator.prepare(self.model)
|
||||
self.device = (
|
||||
accelerator.device
|
||||
) # ← Accelerator figures out local device correctly
|
||||
self.accelerator = accelerator
|
||||
else:
|
||||
# Single GPU
|
||||
self.model = self.model.to(device)
|
||||
self.device = torch.device(device)
|
||||
self.accelerator = None
|
||||
|
||||
self.tokenizer = dllm.utils.get_tokenizer(
|
||||
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
|
||||
)
|
||||
|
||||
# generation params
|
||||
self.mask_id = self.tokenizer.mask_token_id
|
||||
self.batch_size = int(batch_size)
|
||||
self.max_length = int(max_length)
|
||||
self.max_new_tokens = int(max_new_tokens)
|
||||
self.block_length = int(block_length)
|
||||
self.steps = int(steps)
|
||||
self.cfg = float(cfg)
|
||||
self.remasking = remasking
|
||||
self.is_check_greedy = is_check_greedy
|
||||
|
||||
# loglikelihood params
|
||||
self.mc_num = int(mc_num)
|
||||
assert mc_num % self.batch_size == 0
|
||||
self.sampling_eps = 0.0
|
||||
|
||||
def apply_chat_template(
|
||||
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Method to apply a chat template to a list of chat history between user and model.
|
||||
"""
|
||||
chat_templated = self.tokenizer.apply_chat_template(
|
||||
chat_history,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=not add_generation_prompt,
|
||||
)
|
||||
return chat_templated
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self.tokenizer.name_or_path.replace("/", "__")
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def _forward_process(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
b, l = batch.shape
|
||||
|
||||
target_len = (l - prompt_index.sum()).item()
|
||||
k = torch.randint(1, target_len + 1, (), device=batch.device)
|
||||
|
||||
x = torch.round(
|
||||
torch.linspace(
|
||||
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
|
||||
)
|
||||
).long()
|
||||
x = ((x - 1) % target_len) + 1
|
||||
assert x.min() >= 1 and x.max() <= target_len
|
||||
|
||||
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
|
||||
is_mask = indices < x.unsqueeze(1)
|
||||
|
||||
for i in range(b):
|
||||
is_mask[i] = is_mask[i][torch.randperm(target_len)]
|
||||
|
||||
is_mask = torch.cat(
|
||||
(
|
||||
torch.zeros(
|
||||
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
|
||||
),
|
||||
is_mask,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
noisy_batch = torch.where(is_mask, self.mask_id, batch)
|
||||
|
||||
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_logits(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if self.cfg > 0.0:
|
||||
assert len(prompt_index) == batch.shape[1]
|
||||
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
||||
un_batch = batch.clone()
|
||||
un_batch[prompt_index] = self.mask_id
|
||||
batch = torch.cat([batch, un_batch])
|
||||
|
||||
logits = self.model(batch).logits
|
||||
|
||||
if self.cfg > 0.0:
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
|
||||
return logits[:, : batch.shape[1]]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
|
||||
seq = torch.concatenate([prefix, target])[None, :]
|
||||
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
|
||||
loss_acc = []
|
||||
for _ in range(self.mc_num // self.batch_size):
|
||||
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
|
||||
|
||||
mask_indices = perturbed_seq == self.mask_id
|
||||
|
||||
logits = self.get_logits(perturbed_seq, prompt_index)
|
||||
|
||||
loss = (
|
||||
F.cross_entropy(
|
||||
logits[mask_indices], seq[mask_indices], reduction="none"
|
||||
)
|
||||
/ p_mask[mask_indices]
|
||||
)
|
||||
loss = loss.sum() / self.batch_size
|
||||
loss_acc.append(loss.item())
|
||||
|
||||
return -sum(loss_acc) / len(loss_acc)
|
||||
|
||||
@torch.no_grad()
|
||||
def suffix_greedy_prediction(
|
||||
self, prefix: torch.Tensor, target: torch.Tensor
|
||||
) -> bool:
|
||||
if not self.is_check_greedy:
|
||||
return False
|
||||
|
||||
seq = torch.full(
|
||||
(1, len(prefix) + len(target)), self.mask_id, device=self.device
|
||||
)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
prefix, target = prefix.to(self.device), target.to(self.device)
|
||||
seq[0, : len(prefix)] = prefix
|
||||
|
||||
for i in range(len(target)):
|
||||
mask_index = seq == self.mask_id
|
||||
logits = self.get_logits(seq, prompt_index)[mask_index]
|
||||
x0 = torch.argmax(logits, dim=-1)
|
||||
|
||||
p = torch.softmax(logits.to(torch.float32), dim=-1)
|
||||
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
|
||||
dim=-1
|
||||
)
|
||||
_, index = torch.sort(confidence, descending=True)
|
||||
x0[index[1:]] = self.mask_id
|
||||
seq[mask_index] = x0.clone()
|
||||
correct = target == seq[0, len(prefix) :]
|
||||
correct = torch.all(correct)
|
||||
return correct
|
||||
|
||||
def _encode_pair(
|
||||
self, context: str, continuation: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
n_spaces = len(context) - len(context.rstrip())
|
||||
if n_spaces > 0:
|
||||
continuation = context[-n_spaces:] + continuation
|
||||
context = context[:-n_spaces]
|
||||
|
||||
whole_enc = self.tokenizer(context + continuation)["input_ids"]
|
||||
context_enc = self.tokenizer(context)["input_ids"]
|
||||
|
||||
context_enc_len = len(context_enc)
|
||||
continuation_enc = whole_enc[context_enc_len:]
|
||||
|
||||
return context_enc, continuation_enc
|
||||
|
||||
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
||||
def _tokenize(e):
|
||||
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
||||
return {
|
||||
"prefix_text": e["prefix"],
|
||||
"target_text": e["target"],
|
||||
"prefix": prefix,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
ds = []
|
||||
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
|
||||
|
||||
assert max(prompt_len) <= 4096
|
||||
|
||||
out = []
|
||||
with torch.no_grad():
|
||||
for elem in tqdm(ds, desc="Computing likelihood..."):
|
||||
prefix = elem["prefix"]
|
||||
target = elem["target"]
|
||||
|
||||
ll = self.get_loglikelihood(prefix, target)
|
||||
|
||||
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
|
||||
|
||||
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
||||
torch.cuda.empty_cache()
|
||||
return out
|
||||
|
||||
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_until(self, requests: list[Instance]):
|
||||
def _tokenize(e):
|
||||
return {
|
||||
"question": self.tokenizer(e["question"])["input_ids"],
|
||||
"question_text": e["question"],
|
||||
"until": e["until"],
|
||||
}
|
||||
|
||||
ds = [
|
||||
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
|
||||
]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
|
||||
out = []
|
||||
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
|
||||
|
||||
for elem in tqdm(ds, desc="Generating..."):
|
||||
prompt = [elem["question"][1:-1].to(self.device)]
|
||||
stop_tokens = elem["until"]
|
||||
generated_ids = generator.generate(
|
||||
inputs=prompt,
|
||||
steps=self.steps,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
block_length=self.block_length,
|
||||
temperature=0.0,
|
||||
cfg_scale=self.cfg,
|
||||
remasking=self.remasking,
|
||||
)
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
|
||||
)
|
||||
for stop_seq in stop_tokens:
|
||||
if stop_seq in generated_answer:
|
||||
generated_answer = generated_answer.split(stop_seq)[0]
|
||||
|
||||
# remove special tokens
|
||||
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_answer_ids, skip_special_tokens=True
|
||||
)
|
||||
out.append(generated_answer)
|
||||
if self.accelerator is not None:
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_evaluate()
|
||||
6
dllm/dllm/pipelines/dream/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from . import generator, models, trainer, utils
|
||||
from .models.modeling_dream import DreamModel
|
||||
from .models.configuration_dream import DreamConfig
|
||||
from .models.tokenization_dream import DreamTokenizer
|
||||
from .generator import DreamGeneratorConfig, DreamGenerator
|
||||
from .trainer import DreamTrainer
|
||||
533
dllm/dllm/pipelines/dream/eval.py
Normal file
@ -0,0 +1,533 @@
|
||||
"""
|
||||
accelerate launch \
|
||||
--num_processes 2 \
|
||||
dllm/pipelines/dream/eval.py \
|
||||
--tasks gsm8k \
|
||||
--batch_size 1 \
|
||||
--model dream \
|
||||
--device cuda \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=Dream-org/Dream-v0-Base-7B,mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
|
||||
"""
|
||||
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
from lm_eval.api.instance import Instance
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.api.registry import register_model
|
||||
from lm_eval.models.utils import get_dtype
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines.dream import DreamGenerator, DreamGeneratorConfig
|
||||
|
||||
eval_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamEvalConfig(DreamGeneratorConfig):
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
max_new_tokens: int = 128
|
||||
max_length: int = 2048
|
||||
steps: int = 128
|
||||
temperature: float = 0.0
|
||||
alg: str = "entropy"
|
||||
|
||||
pretrained: str = ""
|
||||
batch_size: int = 1
|
||||
device: str = "cuda"
|
||||
dtype: str | torch.dtype = "auto"
|
||||
add_bos_token: bool = False
|
||||
nll_type: str = "mc"
|
||||
log_type: str = "ftb"
|
||||
mc_num: int = 128
|
||||
classifier_free_guidance: float = 1.0
|
||||
sampling_eps: float = 1e-3
|
||||
escape_until: bool = False
|
||||
|
||||
|
||||
@register_model("dream")
|
||||
class DreamEvalHarness(LM):
|
||||
def __init__(
|
||||
self,
|
||||
config: DreamEvalConfig | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Initialize config if not provided
|
||||
if config is None:
|
||||
config = DreamEvalConfig()
|
||||
|
||||
# Pull args from config, allow kwargs to override
|
||||
pretrained = kwargs.get("pretrained", config.pretrained)
|
||||
batch_size = kwargs.get("batch_size", config.batch_size)
|
||||
device = kwargs.get("device", config.device)
|
||||
dtype = kwargs.get("dtype", config.dtype)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
add_bos_token = kwargs.get("add_bos_token", config.add_bos_token)
|
||||
nll_type = kwargs.get("nll_type", config.nll_type)
|
||||
log_type = kwargs.get("log_type", config.log_type)
|
||||
mc_num = kwargs.get("mc_num", config.mc_num)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
classifier_free_guidance = kwargs.get(
|
||||
"classifier_free_guidance", config.classifier_free_guidance
|
||||
)
|
||||
sampling_eps = kwargs.get("sampling_eps", config.sampling_eps)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
top_p = kwargs.get("top_p", config.top_p)
|
||||
top_k = kwargs.get("top_k", config.top_k)
|
||||
alg = kwargs.get("alg", config.alg)
|
||||
alg_temp = kwargs.get("alg_temp", config.alg_temp)
|
||||
escape_until = kwargs.get("escape_until", config.escape_until)
|
||||
|
||||
accelerator = accelerate.Accelerator()
|
||||
|
||||
# Get GLOBAL rank from torch.distributed (not accelerator)
|
||||
if torch.distributed.is_initialized():
|
||||
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
|
||||
self._world_size = (
|
||||
torch.distributed.get_world_size()
|
||||
) # ← GLOBAL world size (16)
|
||||
else:
|
||||
self._rank = 0
|
||||
self._world_size = 1
|
||||
|
||||
# Use accelerator for device placement
|
||||
self.model = dllm.utils.get_model(
|
||||
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# Let accelerator handle device placement
|
||||
self.model = accelerator.prepare(self.model)
|
||||
self.device = (
|
||||
accelerator.device
|
||||
) # ← Accelerator figures out local device correctly
|
||||
self.accelerator = accelerator
|
||||
else:
|
||||
# Single GPU
|
||||
self.model = self.model.to(device)
|
||||
self.device = torch.device(device)
|
||||
self.accelerator = None
|
||||
|
||||
self.tokenizer = dllm.utils.get_tokenizer(
|
||||
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
|
||||
)
|
||||
|
||||
# generation params
|
||||
self.mask_id = self.tokenizer.mask_token_id
|
||||
self.max_length = max_length
|
||||
self.add_bos_token = add_bos_token
|
||||
self.batch_size = int(batch_size)
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.steps = steps
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.alg = alg
|
||||
self.alg_temp = alg_temp
|
||||
self.escape_until = escape_until
|
||||
|
||||
# loglikelihood params
|
||||
self.nll_type = nll_type
|
||||
self.log_type = log_type
|
||||
self.mc_num = mc_num
|
||||
self.classifier_free_guidance = classifier_free_guidance
|
||||
self.sampling_eps = sampling_eps
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def tok_decode(
|
||||
self, tokens: torch.Tensor | list[int], skip_special_tokens: bool = True
|
||||
) -> str:
|
||||
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def tok_encode(self, text: str, add_special_tokens: bool = True) -> torch.Tensor:
|
||||
return self.tokenizer(
|
||||
text, return_tensors="pt", add_special_tokens=add_special_tokens
|
||||
).input_ids
|
||||
|
||||
def apply_chat_template(
|
||||
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Method to apply a chat template to a list of chat history between user and model.
|
||||
"""
|
||||
chat_templated = self.tokenizer.apply_chat_template(
|
||||
chat_history,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=not add_generation_prompt,
|
||||
)
|
||||
return chat_templated
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self.tokenizer.name_or_path.replace("/", "__")
|
||||
|
||||
def generate_until(
|
||||
self, requests: list[Instance], disable_tqdm: bool = False
|
||||
) -> list[str]:
|
||||
res = []
|
||||
pbar = tqdm(
|
||||
total=len(requests),
|
||||
disable=(disable_tqdm or (self.rank != 0)),
|
||||
desc="Running generate_until requests",
|
||||
)
|
||||
generator = DreamGenerator(model=self.model, tokenizer=self.tokenizer)
|
||||
for batch_idx in range(0, len(requests), self.batch_size):
|
||||
batch_requests = requests[batch_idx : batch_idx + self.batch_size]
|
||||
contexts, gen_args = zip(*[req.arguments for req in batch_requests])
|
||||
|
||||
# ====== BEGIN merged _generate_batch logic ======
|
||||
prompts = list(contexts)
|
||||
if self.add_bos_token:
|
||||
prompts = [self.tokenizer.bos_token + p for p in prompts]
|
||||
|
||||
# tokenize
|
||||
prompt_ids = [
|
||||
self.tokenizer(
|
||||
p, return_tensors="pt", padding=False
|
||||
).input_ids.squeeze()
|
||||
for p in prompts
|
||||
]
|
||||
prompt_lens = [len(p_id) for p_id in prompt_ids]
|
||||
|
||||
if max(prompt_lens) > self.max_length - self.max_new_tokens:
|
||||
cutoff_len = self.max_length - self.max_new_tokens
|
||||
eval_logger.warning(
|
||||
f"Prompt length {max(prompt_lens)} exceeds {cutoff_len}, cutoff on the left side"
|
||||
)
|
||||
# Trim from the left side (keep the last cutoff_len tokens)
|
||||
prompt_ids = [p_id[-cutoff_len:] for p_id in prompt_ids]
|
||||
|
||||
# generation
|
||||
generation_ids = generator.generate(
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
inputs=prompt_ids,
|
||||
steps=self.steps,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
alg=self.alg,
|
||||
alg_temp=self.alg_temp,
|
||||
output_history=False,
|
||||
return_dict_in_generate=False,
|
||||
)
|
||||
# decode and cleanup
|
||||
cleaned_generation_ids = [
|
||||
(
|
||||
seq[seq.ne(self.tokenizer.eos_token_id).float().argmax().long() :]
|
||||
if (seq != self.tokenizer.eos_token_id).any()
|
||||
else seq[-1:]
|
||||
)
|
||||
for seq in generation_ids
|
||||
]
|
||||
truncated_generation_ids = [
|
||||
seq[prompt_lens[i] :] for i, seq in enumerate(cleaned_generation_ids)
|
||||
]
|
||||
responses = [
|
||||
g.lstrip("<|endoftext|>").split(self.tokenizer.eos_token, 1)[0]
|
||||
for g in self.tokenizer.batch_decode(truncated_generation_ids)
|
||||
]
|
||||
|
||||
# ====== END merged _generate_batch logic ======
|
||||
|
||||
# handle "until" truncation
|
||||
if not self.escape_until:
|
||||
for i, r in enumerate(responses):
|
||||
for s in gen_args[0]["until"]:
|
||||
r = r.split(s)[0]
|
||||
responses[i] = r
|
||||
|
||||
res.extend(responses)
|
||||
pbar.update(len(contexts))
|
||||
|
||||
return res
|
||||
|
||||
def _forward_process(
|
||||
self, batch: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
b, l = batch.shape
|
||||
# sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
|
||||
u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
|
||||
indices = torch.arange(b, device=batch.device).float()
|
||||
t = (u0 + indices / b) % 1
|
||||
|
||||
p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
|
||||
|
||||
p_mask = p_mask[:, None].repeat(1, l)
|
||||
|
||||
mask_indices = torch.rand((b, l), device=batch.device) < p_mask
|
||||
# always unmask bos and eos
|
||||
mask_indices[:, 0] = False
|
||||
mask_indices[:, -1] = False
|
||||
|
||||
noisy_batch = torch.where(mask_indices, self.mask_id, batch)
|
||||
return noisy_batch, p_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def get_logits(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
prompt_index : 1D bool tensor, length=batch.shape[1]
|
||||
"""
|
||||
if self.classifier_free_guidance > 1.0:
|
||||
assert len(prompt_index) == batch.shape[1]
|
||||
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
||||
un_batch = batch.clone()
|
||||
un_batch[prompt_index] = self.mask_id
|
||||
batch = torch.cat([batch, un_batch])
|
||||
|
||||
input = batch
|
||||
|
||||
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
|
||||
logits = self.model(input).logits
|
||||
# since bos always unmask, the first logits will not be used
|
||||
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
|
||||
if self.classifier_free_guidance > 1.0:
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
|
||||
return logits[:, : batch.shape[1]]
|
||||
|
||||
@torch.no_grad()
|
||||
def _eval_target_nll_mc(
|
||||
self, prefix: torch.Tensor | None, target: torch.Tensor
|
||||
) -> float:
|
||||
if prefix is None:
|
||||
seq = target[None, :]
|
||||
else:
|
||||
seq = torch.concatenate([prefix, target])[None, :]
|
||||
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
||||
|
||||
if self.log_type == "ftb":
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
else:
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
|
||||
|
||||
loss_acc = []
|
||||
for _ in range(max(self.mc_num // self.batch_size, 1)):
|
||||
perturbed_seq = seq.clone()
|
||||
# eval_logger.info("before noising")
|
||||
perturbed_seq_, p_mask = self._forward_process(seq)
|
||||
# eval_logger.info("end noising")
|
||||
if self.log_type == "ftb":
|
||||
perturbed_seq[:, -len(target) :] = perturbed_seq_[:, -len(target) :]
|
||||
elif self.log_type == "btf":
|
||||
perturbed_seq[:, : len(prefix)] = perturbed_seq_[:, : len(prefix)]
|
||||
elif self.log_type == "union":
|
||||
perturbed_seq = perturbed_seq_
|
||||
else:
|
||||
raise NotImplementedError(self.log_type)
|
||||
|
||||
mask_indices = perturbed_seq == self.mask_id
|
||||
logits = self.get_logits(perturbed_seq, prompt_index)
|
||||
loss = (
|
||||
F.cross_entropy(
|
||||
logits[mask_indices], seq[mask_indices], reduction="none"
|
||||
)
|
||||
/ p_mask[mask_indices]
|
||||
)
|
||||
loss = loss.sum() / self.batch_size
|
||||
loss_acc.append(loss.item())
|
||||
|
||||
return sum(loss_acc) / len(loss_acc)
|
||||
|
||||
@torch.no_grad()
|
||||
def _eval_target_nll_ar(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
|
||||
prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
|
||||
assert self.log_type in ["ftb", "btf"]
|
||||
assert self.nll_type in ["ar_ftb", "ar_btf"]
|
||||
|
||||
if self.log_type == "ftb":
|
||||
prompt_index = (
|
||||
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
|
||||
< prefix.shape[1]
|
||||
)
|
||||
else:
|
||||
prompt_index = (
|
||||
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
|
||||
>= prefix.shape[1]
|
||||
)
|
||||
|
||||
if self.log_type == "ftb":
|
||||
perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
|
||||
else:
|
||||
perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
|
||||
|
||||
mask_index = torch.ones(
|
||||
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
|
||||
)
|
||||
if self.nll_type == "ar_ftb":
|
||||
mask_index = torch.triu(mask_index)
|
||||
else:
|
||||
mask_index = torch.tril(mask_index)
|
||||
perturbed_[mask_index] = self.mask_id
|
||||
if self.log_type == "ftb":
|
||||
perturbed_seq = torch.cat(
|
||||
[prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1
|
||||
)
|
||||
else:
|
||||
perturbed_seq = torch.cat(
|
||||
[perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1
|
||||
)
|
||||
|
||||
logits_ = []
|
||||
num = (
|
||||
len(perturbed_seq) // self.batch_size
|
||||
if len(perturbed_seq) % self.batch_size == 0
|
||||
else len(perturbed_seq) // self.batch_size + 1
|
||||
)
|
||||
for i in range(num):
|
||||
end = (
|
||||
(i + 1) * self.batch_size
|
||||
if (i + 1) * self.batch_size < len(perturbed_seq)
|
||||
else len(perturbed_seq)
|
||||
)
|
||||
perturbed_seq_ = perturbed_seq[i * self.batch_size : end]
|
||||
perturbed_seq_ = perturbed_seq_.to(self.device)
|
||||
if len(perturbed_seq_.shape) == 1:
|
||||
perturbed_seq_ = perturbed_seq_.unsqueeze(0)
|
||||
logits = self.get_logits(perturbed_seq_, prompt_index)
|
||||
logits_.append(logits.cpu())
|
||||
logits = torch.cat(logits_, dim=0)
|
||||
|
||||
temp_index = torch.ones(
|
||||
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
|
||||
)
|
||||
if self.nll_type == "ar_ftb":
|
||||
temp_index = torch.triu(temp_index, diagonal=1)
|
||||
else:
|
||||
temp_index = torch.tril(temp_index, diagonal=-1)
|
||||
mask_index[temp_index] = False
|
||||
if self.log_type == "ftb":
|
||||
logits_index = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool
|
||||
),
|
||||
mask_index,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
logits_index = torch.cat(
|
||||
[
|
||||
mask_index,
|
||||
torch.zeros(
|
||||
(perturbed_.shape[1], target.shape[1]), dtype=torch.bool
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if self.log_type == "ftb":
|
||||
loss = (
|
||||
F.cross_entropy(logits[logits_index], target[0], reduction="sum")
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
else:
|
||||
loss = (
|
||||
F.cross_entropy(logits[logits_index], prefix[0], reduction="sum")
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
return loss
|
||||
|
||||
def _encode_pair(
|
||||
self, context: str, continuation: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.add_bos_token:
|
||||
context = self.tokenizer.bos_token + context
|
||||
|
||||
n_spaces = len(context) - len(context.rstrip())
|
||||
if n_spaces > 0:
|
||||
continuation = context[-n_spaces:] + continuation
|
||||
context = context[:-n_spaces]
|
||||
|
||||
whole_enc = self.tokenizer.encode(context + continuation) + [
|
||||
self.tokenizer.eos_token_id
|
||||
]
|
||||
context_enc = self.tokenizer.encode(context)
|
||||
|
||||
context_enc_len = len(context_enc)
|
||||
continuation_enc = whole_enc[context_enc_len:]
|
||||
|
||||
# by default truncate on the left
|
||||
cutoff_length = max(len(whole_enc) - self.max_length, 0)
|
||||
if cutoff_length > 0:
|
||||
eval_logger.warning(
|
||||
f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side"
|
||||
)
|
||||
context_remain = context_enc_len - cutoff_length
|
||||
if context_remain > 0:
|
||||
context_enc = context_enc[-context_remain:]
|
||||
else:
|
||||
eval_logger.warning(f"All context (prompt) is truncated.")
|
||||
context_enc = ""
|
||||
continuation_enc = whole_enc[-self.max_length :]
|
||||
return context_enc, continuation_enc
|
||||
|
||||
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
||||
def _tokenize(e):
|
||||
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
||||
return {
|
||||
"prefix_text": e["prefix"],
|
||||
"target_text": e["target"],
|
||||
"prefix": prefix,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
ds = []
|
||||
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
|
||||
out = []
|
||||
with torch.no_grad():
|
||||
for elem in tqdm(ds, desc="Computing likelihood..."):
|
||||
prefix = elem["prefix"]
|
||||
target = elem["target"]
|
||||
# likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
|
||||
if self.nll_type == "mc":
|
||||
ll = -self._eval_target_nll_mc(prefix, target)
|
||||
if self.log_type == "union":
|
||||
ll = ll / (len(target) + len(prefix))
|
||||
elif self.nll_type == "ar_ftb" or self.nll_type == "ar_btf":
|
||||
ll = -self._eval_target_nll_ar(prefix, target)
|
||||
else:
|
||||
raise NotImplementedError(self.nll_type)
|
||||
|
||||
# TODO: greedy decoding
|
||||
is_target_greedy_dec = False
|
||||
|
||||
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
||||
return out
|
||||
|
||||
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_evaluate()
|
||||
426
dllm/dllm/pipelines/dream/generator.py
Normal file
@ -0,0 +1,426 @@
|
||||
"""
|
||||
reference: https://huggingface.co/Dream-org/Dream-v0-Base-7B/blob/main/generation_utils.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributions as dists
|
||||
|
||||
from dllm.utils.generation_utils import get_num_transfer_tokens
|
||||
from dllm.pipelines.dream.utils import top_p_logits, top_k_logits
|
||||
from dllm.core.generation.generator import (
|
||||
GeneratorOutput,
|
||||
GeneratorConfig,
|
||||
BaseGenerator,
|
||||
)
|
||||
|
||||
|
||||
def sample_tokens(
|
||||
logits: torch.Tensor,
|
||||
temperature: float = 0.0,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
margin_confidence: bool = False,
|
||||
neg_entropy: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if temperature > 0:
|
||||
logits = logits / temperature
|
||||
if top_p is not None and top_p < 1:
|
||||
logits = top_p_logits(logits, top_p)
|
||||
if top_k is not None:
|
||||
logits = top_k_logits(logits, top_k)
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
if temperature > 0:
|
||||
try:
|
||||
x0 = dists.Categorical(probs=probs).sample()
|
||||
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
||||
except Exception:
|
||||
confidence, x0 = probs.max(dim=-1)
|
||||
else:
|
||||
confidence, x0 = probs.max(dim=-1)
|
||||
|
||||
if margin_confidence:
|
||||
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
||||
top1_probs = sorted_probs[:, 0]
|
||||
top2_probs = sorted_probs[:, 1]
|
||||
confidence = top1_probs - top2_probs
|
||||
|
||||
if neg_entropy:
|
||||
epsilon = 1e-10
|
||||
log_probs = torch.log(probs + epsilon)
|
||||
confidence = torch.sum(probs * log_probs, dim=-1)
|
||||
|
||||
return confidence, x0
|
||||
|
||||
|
||||
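A toy call (hedged sketch, made-up sizes) shows the contract: logits for N masked positions go in; one sampled token id and one confidence score per position come out.

import torch

toy_logits = torch.randn(5, 32)  # 5 masked positions, vocab of 32
confidence, x0 = sample_tokens(toy_logits, temperature=0.2, top_p=0.95)
assert x0.shape == confidence.shape == (5,)

# With neg_entropy=True the returned "confidence" is the negative entropy of the
# (temperature-scaled, top-p-truncated) distribution, so peaked positions rank first.
neg_ent, _ = sample_tokens(toy_logits, temperature=0.2, top_p=0.95, neg_entropy=True)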
@dataclass
|
||||
class DreamGeneratorConfig(GeneratorConfig):
|
||||
max_new_tokens: int = 20
|
||||
max_length: int | None = None  # if unset, resolved as prompt length + max_new_tokens (cf. HF: generation_config.max_length + input_ids_length)
|
||||
steps: int = 512
|
||||
eps: float = 1e-3
|
||||
alg: str = "origin"
|
||||
alg_temp: float = 0.0
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = 50
|
||||
stochastic_transfer: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamGenerator(BaseGenerator):
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: list[torch.Tensor] | list[list[int]],
|
||||
config: DreamGeneratorConfig | None = None,
|
||||
generation_tokens_hook_func=lambda step, x, logits: x,
|
||||
generation_logits_hook_func=lambda step, x, logits: logits,
|
||||
**kwargs,
|
||||
) -> GeneratorOutput | torch.Tensor:
|
||||
"""
|
||||
Diffusion-style masked decoding for *generation from inputs*.
|
||||
(docstring unchanged)
|
||||
"""
|
||||
if config is None:
|
||||
config = DreamGeneratorConfig()
|
||||
|
||||
# ----- pull args from config, allow kwargs to override -----
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
eps = kwargs.get("eps", config.eps)
|
||||
alg = kwargs.get("alg", config.alg)
|
||||
alg_temp = kwargs.get("eps", config.alg_temp)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
top_p = kwargs.get("top_p", config.top_p)
|
||||
top_k = kwargs.get("top_k", config.top_k)
|
||||
stochastic_transfer = kwargs.get(
|
||||
"stochastic_transfer", config.stochastic_transfer
|
||||
)
|
||||
# generation_tokens_hook_func = kwargs.get("generation_tokens_hook_func", config.generation_tokens_hook_func)
|
||||
# generation_logits_hook_func = kwargs.get("generation_logits_hook_func", config.generation_logits_hook_func)
|
||||
return_dict_in_generate = kwargs.get(
|
||||
"return_dict_in_generate", config.return_dict_in_generate
|
||||
)
|
||||
|
||||
# --- Initialization ---
|
||||
mask_token_id = self.tokenizer.mask_token_id
|
||||
eos_token_id = self.tokenizer.eos_token_id
|
||||
|
||||
if isinstance(inputs[0], list):
|
||||
inputs = [
|
||||
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
|
||||
for p in inputs
|
||||
]
|
||||
prompt_lens = [p.shape[0] for p in inputs]
|
||||
if max_new_tokens:
|
||||
max_length = max_new_tokens + max(prompt_lens)
|
||||
else:
|
||||
max_new_tokens = max_length - max(prompt_lens)
|
||||
|
||||
B = len(inputs)
|
||||
T = max_length
|
||||
x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)
|
||||
|
||||
seq_length = []
|
||||
for i, p in enumerate(inputs):
|
||||
total_len = prompt_lens[i] + max_new_tokens
|
||||
seq_length.append(total_len)
|
||||
start = T - total_len
|
||||
x[i, start : start + prompt_lens[i]] = p
|
||||
x[i, start + prompt_lens[i] : T] = mask_token_id
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
(B, T), dtype=torch.float32, device=self.model.device
|
||||
)
|
||||
for j, L in enumerate(seq_length):
|
||||
if L > 0:
|
||||
attention_mask[j, -L:] = 1.0  # sequences are left-padded, so valid tokens sit on the right
|
||||
|
||||
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
||||
pos_id = attention_mask.long().cumsum(-1) - 1
|
||||
pos_id.masked_fill_(attention_mask == 0, 1)
|
||||
else:
|
||||
pos_id = None
|
||||
|
||||
mask_index = x == mask_token_id
|
||||
num_transfer_tokens_list = get_num_transfer_tokens(
|
||||
mask_index=mask_index,
|
||||
steps=steps,
|
||||
scheduler=self.scheduler,
|
||||
stochastic=stochastic_transfer,
|
||||
)
|
||||
effective_steps = num_transfer_tokens_list.size(1)
|
||||
|
||||
# --- Iterative refinement ---
|
||||
x = generation_tokens_hook_func(None, x, None)
|
||||
histories = [x.clone()] if return_dict_in_generate else None
|
||||
for i in range(effective_steps):
|
||||
mask_index = x == mask_token_id
|
||||
|
||||
logits = self.model(x, attention_mask, pos_id).logits
|
||||
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
logits = generation_logits_hook_func(i, x, logits)
|
||||
|
||||
mask_logits = logits[mask_index]
|
||||
|
||||
if alg == "maskgit_plus":
|
||||
confidence, x0 = sample_tokens(
|
||||
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
|
||||
)
|
||||
elif alg == "topk_margin":
|
||||
confidence, x0 = sample_tokens(
|
||||
mask_logits,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
margin_confidence=True,
|
||||
)
|
||||
elif alg == "entropy":
|
||||
confidence, x0 = sample_tokens(
|
||||
mask_logits,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
neg_entropy=True,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown alg: {alg}")
|
||||
|
||||
full_confidence = torch.full_like(
|
||||
x, -torch.inf, device=self.model.device, dtype=logits.dtype
|
||||
)
|
||||
full_confidence[mask_index] = confidence
|
||||
|
||||
for j in range(full_confidence.shape[0]):
|
||||
number_transfer_tokens = num_transfer_tokens_list[j, i]
|
||||
if number_transfer_tokens > 0:
|
||||
if alg_temp is None or alg_temp == 0:
|
||||
_, transfer_index = torch.topk(
|
||||
full_confidence[j], number_transfer_tokens
|
||||
)
|
||||
else:
|
||||
fc = full_confidence[j] / alg_temp
|
||||
fc = F.softmax(fc, dim=-1)
|
||||
transfer_index = torch.multinomial(
|
||||
fc, num_samples=number_transfer_tokens
|
||||
)
|
||||
|
||||
x_ = torch.full_like(x, mask_token_id, device=self.model.device)
|
||||
x_[mask_index] = x0.clone()
|
||||
x[j, transfer_index] = x_[j, transfer_index]
|
||||
|
||||
x = generation_tokens_hook_func(i, x, logits)
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
if not return_dict_in_generate:
|
||||
return x
|
||||
else:
|
||||
return GeneratorOutput(sequences=x, histories=histories)
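# --- Minimal sketch (illustration only, outside the class): the right-aligned
# --- canvas that `generate` builds before denoising. Token ids are made up;
# --- `mask_token_id` / `eos_token_id` stand in for the tokenizer's real ids.
import torch

mask_token_id, eos_token_id = 9, 0
prompts = [torch.tensor([5, 6, 7]), torch.tensor([8])]
max_new_tokens = 4

prompt_lens = [p.shape[0] for p in prompts]
T = max_new_tokens + max(prompt_lens)                  # canvas width
x = torch.full((len(prompts), T), eos_token_id, dtype=torch.long)
for i, p in enumerate(prompts):
    start = T - (prompt_lens[i] + max_new_tokens)      # left side stays EOS (padding)
    x[i, start:start + prompt_lens[i]] = p             # prompt, right-aligned
    x[i, start + prompt_lens[i]:] = mask_token_id      # masks to be denoised
print(x)
# tensor([[5, 6, 7, 9, 9, 9, 9],
#         [0, 0, 8, 9, 9, 9, 9]])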
|
||||
|
||||
@torch.no_grad()
|
||||
def infill(
|
||||
self,
|
||||
inputs: list[torch.Tensor] | list[list[int]],
|
||||
config,
|
||||
generation_tokens_hook_func=lambda step, x, logits: x,
|
||||
generation_logits_hook_func=lambda step, x, logits: logits,
|
||||
**kwargs,
|
||||
) -> GeneratorOutput | torch.Tensor:
|
||||
"""
|
||||
Fill in-place the tokenizer's `<mask>` tokens contained in `inputs`.
|
||||
The whole (right-aligned) canvas is denoised iteratively: at each step, a scheduler
|
||||
decides how many masked positions to commit, and a confidence rule (`alg`)
|
||||
selects *which* positions to reveal (MaskGIT-style). Non-mask tokens are never changed.
|
||||
|
||||
High-level:
|
||||
1) Build a right-aligned canvas per sample (left side padded with EOS).
|
||||
2) Compute a per-sample transfer schedule via `scheduler` and `steps`.
|
||||
3) At each step: forward pass → AR-shift logits → score masked positions
|
||||
via `alg` → choose indices to commit (top-k or soft sampling) → write tokens.
|
||||
|
||||
Notes:
|
||||
- Left padding uses EOS (EOS doubles as the pad token here).
|
||||
- Only `[MASK]` positions are updated; original tokens remain intact.
|
||||
- Logits are AR-shifted to preserve next-token prediction alignment.
|
||||
|
||||
Args:
|
||||
model:
|
||||
Mask predictor; returns logits of shape [B, T, V] when called as
|
||||
`model(x, attention_mask, pos_id)`.
|
||||
tokenizer:
|
||||
Must provide `mask_token_id` and `eos_token_id`.
|
||||
inputs:
|
||||
List of 1D LongTensors (token ids). Each may contain `<mask>` tokens
|
||||
to be filled; other tokens are treated as fixed context.
|
||||
scheduler (BaseAlphaScheduler):
|
||||
Controls how many masks to commit per step (deterministic or stochastic).
|
||||
generation_tokens_hook_func / generation_logits_hook_func:
|
||||
Optional hooks to intercept tokens/logits at each step.
|
||||
output_history (bool):
|
||||
If True, save intermediate canvases at each step.
|
||||
return_dict_in_generate (bool):
|
||||
If True, return `GeneratorOutput(sequences, histories)`, else only `[B, T]`.
|
||||
steps (int):
|
||||
Total reverse-diffusion steps (quality–speed trade-off).
|
||||
alg (str):
|
||||
Confidence rule to rank masked positions:
|
||||
- "maskgit_plus": softmax probs
|
||||
- "topk_margin": top1 - top2 margin
|
||||
- "entropy": negative entropy
|
||||
alg_temp (float):
|
||||
Temperature for *confidence-based index sampling* (when > 0, soft selection).
|
||||
temperature / top_p / top_k:
|
||||
Token sampling hyperparameters within `sample_tokens`.
|
||||
stochastic_transfer (bool):
|
||||
If True, sample the number of transfers per step (Binomial); else use expectation.
|
||||
|
||||
Returns:
GeneratorOutput | torch.LongTensor:
If `return_dict_in_generate=True`, returns
- sequences: `[B, T]` final tokens
- histories: optional list of intermediate canvases
Otherwise returns only `[B, T]`.
|
||||
"""
|
||||
# ----- pull args from config, allow kwargs to override -----
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
eps = kwargs.get("eps", config.eps)
|
||||
alg = kwargs.get("alg", config.alg)
|
||||
alg_temp = kwargs.get("eps", config.alg_temp)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
top_p = kwargs.get("top_p", config.top_p)
|
||||
top_k = kwargs.get("top_k", config.top_k)
|
||||
stochastic_transfer = kwargs.get(
|
||||
"stochastic_transfer", config.stochastic_transfer
|
||||
)
|
||||
# generation_tokens_hook_func = kwargs.get("stochastic_transfer", config.generation_tokens_hook_func)
|
||||
# generation_logits_hook_func = kwargs.get("stochastic_transfer", config.generation_logits_hook_func)
|
||||
return_dict_in_generate = kwargs.get(
|
||||
"return_dict_in_generate", config.return_dict_in_generate
|
||||
)
|
||||
|
||||
# --- Initialization ---
|
||||
mask_token_id = self.tokenizer.mask_token_id
|
||||
eos_token_id = self.tokenizer.eos_token_id
|
||||
|
||||
if isinstance(inputs[0], list):
|
||||
inputs = [
|
||||
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
|
||||
for p in inputs
|
||||
]
|
||||
|
||||
B = len(inputs)
|
||||
seq_lens = [t.shape[0] for t in inputs]
|
||||
T = max(seq_lens)
|
||||
|
||||
# Build right-aligned canvas; left side padded with EOS (used as pad)
|
||||
x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)
|
||||
for i, t in enumerate(inputs):
|
||||
L = seq_lens[i]
|
||||
x[i, -L:] = t
|
||||
|
||||
# Build 1D attention mask (valid tokens on the right)
|
||||
attention_mask = torch.zeros((B, T), dtype=torch.bool, device=self.model.device)
|
||||
for j, L in enumerate(seq_lens):
|
||||
if L > 0:
|
||||
attention_mask[j, -L:] = True
|
||||
|
||||
# Expand to pairwise attention if left padding is present
|
||||
if torch.any(attention_mask == 0.0):
|
||||
pos_id = attention_mask.long().cumsum(-1) - 1
|
||||
pos_id.masked_fill_(attention_mask == 0, 1)
|
||||
else:
|
||||
pos_id = None
|
||||
attention_mask = "full"
|
||||
|
||||
# Precompute per-sample transfer schedule (how many to commit per step)
|
||||
mask_index = x == mask_token_id
|
||||
num_transfer_tokens_list = get_num_transfer_tokens(
|
||||
mask_index=mask_index,
|
||||
steps=steps,
|
||||
scheduler=self.scheduler,
|
||||
stochastic=stochastic_transfer,
|
||||
)
|
||||
effective_steps = num_transfer_tokens_list.size(1)
|
||||
|
||||
# Optional initial token hook
|
||||
x = generation_tokens_hook_func(None, x, None)
|
||||
histories = [x.clone()] if return_dict_in_generate else None
|
||||
for i in range(effective_steps):
|
||||
mask_index = x == mask_token_id
|
||||
|
||||
# Forward pass, then AR-shift to predict token at position i+1
|
||||
logits = self.model(x, attention_mask, pos_id).logits
|
||||
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
logits = generation_logits_hook_func(i, x, logits)
|
||||
|
||||
# Logits restricted to current `[MASK]` positions
|
||||
mask_logits = logits[mask_index]
|
||||
|
||||
# Confidence scoring for masked positions
|
||||
if alg == "maskgit_plus":
|
||||
confidence, x0 = sample_tokens(
|
||||
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
|
||||
)
|
||||
elif alg == "topk_margin":
|
||||
confidence, x0 = sample_tokens(
|
||||
mask_logits,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
margin_confidence=True,
|
||||
)
|
||||
elif alg == "entropy":
|
||||
confidence, x0 = sample_tokens(
|
||||
mask_logits,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
neg_entropy=True,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown alg: {alg}")
|
||||
|
||||
# Scatter per-position confidence back to full canvas
|
||||
full_confidence = torch.full_like(
|
||||
x, -torch.inf, device=self.model.device, dtype=logits.dtype
|
||||
)
|
||||
full_confidence[mask_index] = confidence
|
||||
|
||||
# Commit the scheduled number of tokens per sample
|
||||
for j in range(B):
|
||||
number_transfer_tokens = num_transfer_tokens_list[j, i]
|
||||
if number_transfer_tokens > 0:
|
||||
if alg_temp is None or alg_temp == 0:
|
||||
_, transfer_index = torch.topk(
|
||||
full_confidence[j], number_transfer_tokens
|
||||
)
|
||||
else:
|
||||
fc = full_confidence[j] / alg_temp
|
||||
fc = F.softmax(fc, dim=-1)
|
||||
transfer_index = torch.multinomial(
|
||||
fc, num_samples=number_transfer_tokens
|
||||
)
|
||||
|
||||
# Candidate tokens at masked positions only
|
||||
x_ = torch.full_like(x, mask_token_id, device=self.model.device)
|
||||
x_[mask_index] = x0.clone()
|
||||
x[j, transfer_index] = x_[j, transfer_index]
|
||||
|
||||
# Optional token hook + history logging
|
||||
x = generation_tokens_hook_func(i, x, logits)
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
if not return_dict_in_generate:
|
||||
return x
|
||||
else:
|
||||
return GeneratorOutput(sequences=x, histories=histories)
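# --- Minimal sketch (illustration only): the per-step commit rule shared by
# --- `generate` and `infill`: scatter confidences over masked positions, keep the
# --- top-k most confident ones, write the sampled tokens there. All values are toy.
import torch

mask_token_id = 9
x = torch.tensor([[5, 9, 9, 9]])               # one sample with three masks
mask_index = x == mask_token_id

x0 = torch.tensor([11, 12, 13])                # sampled candidates for the masked slots
confidence = torch.tensor([0.2, 0.9, 0.5])     # one score per masked position
number_transfer_tokens = 2                     # scheduler says: commit two this step

full_confidence = torch.full(x.shape, -torch.inf)
full_confidence[mask_index] = confidence
_, transfer_index = torch.topk(full_confidence[0], number_transfer_tokens)

x_ = torch.full_like(x, mask_token_id)
x_[mask_index] = x0
x[0, transfer_index] = x_[0, transfer_index]
print(x)  # tensor([[ 5,  9, 12, 13]]) -- the two most confident masks were committed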
|
||||
13
dllm/dllm/pipelines/dream/models/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
from .configuration_dream import DreamConfig
|
||||
from .modeling_dream import DreamModel
|
||||
|
||||
# Register with HuggingFace Auto classes for local usage
|
||||
try:
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
|
||||
|
||||
AutoConfig.register("Dream", DreamConfig)
|
||||
AutoModel.register(DreamConfig, DreamModel)
|
||||
AutoModelForMaskedLM.register(DreamConfig, DreamModel)
|
||||
except ImportError:
|
||||
# transformers not available or Auto classes not imported
|
||||
pass
|
||||
85
dllm/dllm/pipelines/dream/models/configuration_dream.py
Normal file
@ -0,0 +1,85 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dream model configuration"""
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DreamConfig(PretrainedConfig):
|
||||
model_type = "Dream"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151936,
|
||||
hidden_size=4096,
|
||||
intermediate_size=22016,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=False, # cache not used in diffusion
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
attention_dropout=0.0,
|
||||
mask_token_id=151666,
|
||||
pad_token_id=151643,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window if use_sliding_window else None
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.mask_token_id = mask_token_id
|
||||
self.pad_token_id = pad_token_id
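# --- Illustrative sketch (assumes the DreamConfig class defined above is in scope):
# --- building a deliberately tiny config for local tests. All sizes below are
# --- made-up overrides of the 7B defaults; the special-token ids keep the defaults.
tiny_config = DreamConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=256,
)
print(tiny_config.model_type, tiny_config.mask_token_id, tiny_config.pad_token_id)  # Dream 151666 151643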
|
||||
465
dllm/dllm/pipelines/dream/models/generation_utils.py
Normal file
@ -0,0 +1,465 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributions as dists
|
||||
from torch.nn import functional as F
|
||||
from transformers import __version__
|
||||
from transformers.generation.configuration_utils import (
|
||||
GenerationConfig
|
||||
)
|
||||
from transformers.utils import (
|
||||
ModelOutput,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def top_p_logits(logits, top_p=None):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
||||
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
||||
return logits
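# --- Illustrative sketch (uses top_p_logits defined above): with top_p=0.7, tokens
# --- are kept in probability order until the cumulative mass first exceeds 0.7
# --- (the token crossing the threshold is kept); the rest drop to ~0 probability.
import torch

toy = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))   # probabilities 0.5 / 0.3 / 0.15 / 0.05
filtered = top_p_logits(toy, top_p=0.7)
print(torch.softmax(filtered, dim=-1))
# tensor([[0.6250, 0.3750, 0.0000, 0.0000]]) -- only the 0.5 and 0.3 tokens survive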
|
||||
|
||||
def top_k_logits(logits, top_k=None):
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
||||
return logits
|
||||
|
||||
|
||||
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
|
||||
|
||||
if temperature > 0:
|
||||
logits = logits / temperature
|
||||
if top_p is not None and top_p < 1:
|
||||
logits = top_p_logits(logits, top_p)
|
||||
if top_k is not None:
|
||||
logits = top_k_logits(logits, top_k)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
if temperature > 0:
|
||||
try:
|
||||
x0 = dists.Categorical(probs=probs).sample()
|
||||
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
||||
except Exception:
|
||||
confidence, x0 = probs.max(dim=-1)
|
||||
else:
|
||||
confidence, x0 = probs.max(dim=-1)
|
||||
|
||||
if margin_confidence:
|
||||
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
||||
# Extract top1 and top2 probabilities
|
||||
top1_probs = sorted_probs[:, 0]
|
||||
top2_probs = sorted_probs[:, 1]
|
||||
# Calculate confidence as top1 - top2
|
||||
confidence = top1_probs - top2_probs
|
||||
|
||||
if neg_entropy:
|
||||
epsilon = 1e-10
|
||||
log_probs = torch.log(probs + epsilon)
|
||||
confidence = torch.sum(probs * log_probs, dim=-1)
|
||||
|
||||
return confidence, x0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamModelOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
history: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class DreamGenerationConfig(GenerationConfig):
|
||||
def __init__(self, **kwargs):
|
||||
self.temperature: float = kwargs.pop("temperature", 0.0)
|
||||
self.top_p: Optional[float] = kwargs.pop("top_p", None)
|
||||
self.top_k: Optional[int] = kwargs.pop("top_k", None)
|
||||
self.max_length = kwargs.pop("max_length", 20)
|
||||
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
|
||||
# diffusion specific params
|
||||
self.eps: float = kwargs.pop("eps", 1e-3)
|
||||
self.steps: int = kwargs.pop("steps", 512)
|
||||
self.alg: str = kwargs.pop("alg", 'origin')
|
||||
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
|
||||
|
||||
# Parameters that define the output variables of `generate`
|
||||
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
|
||||
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
|
||||
self.output_history: bool = kwargs.pop("output_history", False)
|
||||
|
||||
# Special tokens that can be used at generation time
|
||||
self.mask_token_id = kwargs.pop("mask_token_id", None)
|
||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
||||
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
||||
|
||||
# Wild card
|
||||
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
||||
|
||||
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
||||
# interface.
|
||||
self._from_model_config = kwargs.pop("_from_model_config", False)
|
||||
self._commit_hash = kwargs.pop("_commit_hash", None)
|
||||
self.transformers_version = kwargs.pop("transformers_version", __version__)
|
||||
|
||||
# Additional attributes without default values
|
||||
if not self._from_model_config:
|
||||
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
|
||||
# model's default configuration file
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except AttributeError as err:
|
||||
logger.error(f"Can't set {key} with value {value} for {self}")
|
||||
raise err
|
||||
|
||||
# Validate the values of the attributes
|
||||
self.validate(is_init=True)
|
||||
|
||||
def validate(self, is_init=False, **kwargs):
|
||||
pass
|
||||
|
||||
class DreamGenerationMixin:
|
||||
@staticmethod
|
||||
def _expand_inputs_for_generation(
|
||||
expand_size: int = 1,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None
|
||||
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
|
||||
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
|
||||
# Do not call torch.repeat_interleave if expand_size is 1 because it clones
|
||||
# the input tensor and thus requires more memory although no change is applied
|
||||
if expand_size == 1:
|
||||
return input_ids, attention_mask
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
||||
return input_ids, attention_mask
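# --- Illustrative sketch: what the expansion above does for expand_size=2
# --- (i.e. num_return_sequences=2). Each prompt row is duplicated along the
# --- batch dimension so independent samples can be drawn for the same input.
import torch

input_ids = torch.tensor([[1, 2, 3],
                          [4, 5, 6]])
print(input_ids.repeat_interleave(2, dim=0))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [4, 5, 6]])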
|
||||
|
||||
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
||||
"""Performs validation related to the resulting generated length"""
|
||||
|
||||
# Can't throw warnings/exceptions during compilation
|
||||
if is_torchdynamo_compiling():
|
||||
return
|
||||
|
||||
# 1. Max length warnings related to poor parameterization
|
||||
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
|
||||
# 20 is the default max_length of the generation config
|
||||
warnings.warn(
|
||||
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
|
||||
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
|
||||
"generation.",
|
||||
UserWarning,
|
||||
)
|
||||
if input_ids_length >= generation_config.max_length:
|
||||
input_ids_string = "input_ids"
|
||||
raise ValueError(
|
||||
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
|
||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||
" increasing `max_length` or, better yet, setting `max_new_tokens`."
|
||||
)
|
||||
|
||||
def _prepare_generated_length(
|
||||
self,
|
||||
generation_config,
|
||||
has_default_max_length,
|
||||
input_ids_length,
|
||||
):
|
||||
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
|
||||
|
||||
if generation_config.max_new_tokens is not None:
|
||||
if not has_default_max_length and generation_config.max_length is not None:
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||
|
||||
elif has_default_max_length:
|
||||
if generation_config.max_length == DreamGenerationConfig().max_length:
|
||||
generation_config.max_length = generation_config.max_length + input_ids_length
|
||||
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
||||
if max_position_embeddings is not None:
|
||||
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
||||
|
||||
return generation_config
|
||||
|
||||
def _prepare_generation_config(
|
||||
self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
|
||||
) -> DreamGenerationConfig:
|
||||
"""
|
||||
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
||||
function handles retrocompatibility with respect to configuration files.
|
||||
"""
|
||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||
using_model_generation_config = False
|
||||
if generation_config is None:
|
||||
generation_config = DreamGenerationConfig.from_model_config(self.config)
|
||||
using_model_generation_config = True
|
||||
|
||||
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
|
||||
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
|
||||
# exception will be raised in `_validate_model_kwargs`
|
||||
if not is_torchdynamo_compiling():
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
_kwargs = generation_config.update(**kwargs)
|
||||
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
|
||||
if not using_model_generation_config:
|
||||
if generation_config.bos_token_id is None:
|
||||
generation_config.bos_token_id = self.generation_config.bos_token_id
|
||||
if generation_config.eos_token_id is None:
|
||||
generation_config.eos_token_id = self.generation_config.eos_token_id
|
||||
if generation_config.pad_token_id is None:
|
||||
generation_config.pad_token_id = self.generation_config.pad_token_id
|
||||
if generation_config.mask_token_id is None:
|
||||
generation_config.mask_token_id = self.generation_config.mask_token_id
|
||||
|
||||
return generation_config
|
||||
|
||||
def _prepare_special_tokens(
|
||||
self,
|
||||
generation_config: DreamGenerationConfig,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
):
|
||||
"""
|
||||
Prepares the special tokens for generation, overwriting the generation config with their processed versions
|
||||
converted to tensor.
|
||||
|
||||
Note that `generation_config` is changed in place and stops being serializable after this method is called.
|
||||
That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
|
||||
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
||||
"""
|
||||
|
||||
# Convert special tokens to tensors
|
||||
def _tensor_or_none(token, device=None):
|
||||
if token is None:
|
||||
return token
|
||||
|
||||
device = device if device is not None else self.device
|
||||
if isinstance(token, torch.Tensor):
|
||||
return token.to(device)
|
||||
return torch.tensor(token, device=device, dtype=torch.long)
|
||||
|
||||
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
|
||||
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
|
||||
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
|
||||
mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
|
||||
|
||||
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
||||
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
|
||||
eos_token_tensor = eos_token_tensor.unsqueeze(0)
|
||||
|
||||
# Set pad token if unset (and there are conditions to do so)
|
||||
if pad_token_tensor is None and eos_token_tensor is not None:
|
||||
pad_token_tensor = eos_token_tensor[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
|
||||
|
||||
# Update generation config with the updated special tokens tensors
|
||||
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
|
||||
# (in their non-tensor form), in order to enable end-to-end compilation. See
|
||||
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
|
||||
generation_config._bos_token_tensor = bos_token_tensor
|
||||
generation_config._eos_token_tensor = eos_token_tensor
|
||||
generation_config._pad_token_tensor = pad_token_tensor
|
||||
generation_config._mask_token_tensor = mask_token_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def diffusion_generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
generation_config: Optional[DreamGenerationConfig] = None,
|
||||
**kwargs,
|
||||
) -> Union[DreamModelOutput, torch.LongTensor]:
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
generation_config = self._prepare_generation_config(generation_config, **kwargs)
|
||||
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
|
||||
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
|
||||
|
||||
# 2. Define model inputs
|
||||
assert inputs is not None
|
||||
input_ids = inputs
|
||||
device = input_ids.device
|
||||
attention_mask = kwargs.pop("attention_mask", None)
|
||||
self._prepare_special_tokens(generation_config, device=device)
|
||||
|
||||
# 3. Prepare `max_length`.
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
generation_config = self._prepare_generated_length(
|
||||
generation_config=generation_config,
|
||||
has_default_max_length=has_default_max_length,
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 4. Check input_ids
|
||||
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
|
||||
warnings.warn(
|
||||
"You are calling .generate() with the `input_ids` being on a device type different"
|
||||
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
|
||||
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
|
||||
" Please make sure that you have put `input_ids` to the"
|
||||
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
||||
" running `.generate()`.",
|
||||
UserWarning,
|
||||
)
|
||||
if (
|
||||
hasattr(generation_config, "pad_token_id") and
|
||||
torch.any(input_ids == generation_config.pad_token_id) and
|
||||
attention_mask is None
|
||||
):
|
||||
warnings.warn(
|
||||
"Padding was detected but no attention mask is passed here. For correct "
|
||||
"generation results, please set `attention_mask` when batch-padding inputs.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
input_ids, attention_mask = self._expand_inputs_for_generation(
|
||||
expand_size=generation_config.num_return_sequences,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask
|
||||
)
|
||||
|
||||
result = self._sample(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
generation_tokens_hook_func=generation_tokens_hook_func,
|
||||
generation_logits_hook_func=generation_logits_hook_func
|
||||
)
|
||||
return result
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor],
|
||||
generation_config: DreamGenerationConfig,
|
||||
generation_tokens_hook_func,
|
||||
generation_logits_hook_func
|
||||
) -> Union[DreamModelOutput, torch.LongTensor]:
|
||||
# init values
|
||||
|
||||
output_history = generation_config.output_history
|
||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||
max_length = generation_config.max_length
|
||||
mask_token_id = generation_config.mask_token_id
|
||||
steps = generation_config.steps
|
||||
eps = generation_config.eps
|
||||
alg = generation_config.alg
|
||||
alg_temp = generation_config.alg_temp
|
||||
temperature = generation_config.temperature
|
||||
top_p = generation_config.top_p
|
||||
top_k = generation_config.top_k
|
||||
|
||||
histories = [] if (return_dict_in_generate and output_history) else None
|
||||
|
||||
# pad input_ids to max_length
|
||||
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
|
||||
|
||||
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
||||
# we do not mask the [MASK] tokens so value = 1.0
|
||||
attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
|
||||
tok_idx = attention_mask.long().cumsum(-1) - 1
|
||||
tok_idx.masked_fill_(attention_mask == 0, 1)
|
||||
# attention_mask is of shape [B, N]
|
||||
# broadcast to [B, 1, N, N]
|
||||
attention_mask = torch.logical_and(
|
||||
attention_mask.unsqueeze(1).unsqueeze(-2),
|
||||
attention_mask.unsqueeze(1).unsqueeze(-1),
|
||||
)
|
||||
else:
|
||||
tok_idx = None
|
||||
attention_mask = "full"
|
||||
|
||||
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
||||
|
||||
# this allows user-defined token control of the intermediate steps
|
||||
x = generation_tokens_hook_func(None, x, None)
|
||||
for i in range(steps):
|
||||
|
||||
mask_index = (x == mask_token_id)
|
||||
logits = self(x, attention_mask, tok_idx).logits
|
||||
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
||||
# this allows user-defined logits control of the intermediate steps
|
||||
logits = generation_logits_hook_func(i, x, logits)
|
||||
|
||||
mask_logits = logits[mask_index]
|
||||
t = timesteps[i]
|
||||
s = timesteps[i + 1]
|
||||
|
||||
if alg == 'origin':
|
||||
p_transfer = 1 - s / t if i < steps - 1 else 1
|
||||
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
||||
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
||||
_, x0[transfer_index_t_s] = sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
x[mask_index] = x0.clone()
|
||||
else:
|
||||
if alg == 'maskgit_plus':
|
||||
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
elif alg == 'topk_margin':
|
||||
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
||||
elif alg == 'entropy':
|
||||
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown alg: {alg}")
|
||||
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
||||
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
||||
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
||||
full_confidence[mask_index] = confidence
|
||||
if number_transfer_tokens > 0:
|
||||
if alg_temp is None or alg_temp == 0:
|
||||
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
||||
else:
|
||||
full_confidence = full_confidence / alg_temp
|
||||
full_confidence = F.softmax(full_confidence, dim=-1)
|
||||
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
||||
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
||||
x_[mask_index] = x0.clone()
|
||||
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
||||
x[row_indices, transfer_index] = x_[row_indices, transfer_index]
|
||||
|
||||
# this allows user-defined token control of the intermediate steps
|
||||
x = generation_tokens_hook_func(i, x, logits)
|
||||
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
if return_dict_in_generate:
|
||||
return DreamModelOutput(
|
||||
sequences=x,
|
||||
history=histories,
|
||||
)
|
||||
else:
|
||||
return x
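# --- Illustrative sketch: the reverse-diffusion schedule used by `_sample`.
# --- Timesteps decrease linearly from 1 to `eps`; at step i a fraction 1 - s/t of
# --- the masks still remaining is committed (everything left at the final step).
# --- Numbers below are toy values.
import torch

steps, eps = 4, 1e-3
num_masked = 100
timesteps = torch.linspace(1, eps, steps + 1)
for i in range(steps):
    t, s = timesteps[i], timesteps[i + 1]
    frac = (1 - s / t).item() if i < steps - 1 else 1.0
    commit = int(num_masked * frac)
    num_masked -= commit
    print(f"step {i}: commit {commit} tokens, {num_masked} masks left")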
|
||||
850
dllm/dllm/pipelines/dream/models/modeling_dream.py
Normal file
@ -0,0 +1,850 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT and Qwen implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Dream model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
)
|
||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
)
|
||||
from transformers import PretrainedConfig
|
||||
from .configuration_dream import DreamConfig
|
||||
from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "Dream-7B"
|
||||
_CONFIG_FOR_DOC = "DreamConfig"
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
|
||||
class DreamRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
DreamRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
|
||||
class DreamRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[DreamConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
def reset_parameters(self):
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
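# --- Illustrative sketch (uses apply_rotary_pos_emb defined above): shape check.
# --- q/k are [batch, heads, seq, head_dim]; cos/sin come out of the rotary module
# --- as [batch, seq, head_dim], so unsqueeze_dim=1 inserts the missing head axis
# --- for broadcasting. Values are random; this only verifies shapes.
import torch

B, H, T, D = 2, 4, 6, 8
q, k = torch.randn(B, H, T, D), torch.randn(B, H, T, D)
cos, sin = torch.randn(B, T, D), torch.randn(B, T, D)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 6, 8]) torch.Size([2, 4, 6, 8])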
|
||||
|
||||
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
|
||||
class DreamMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
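# --- Illustrative sketch (uses repeat_kv defined above): each key/value head is
# --- duplicated n_rep times so grouped-query attention can reuse fewer KV heads;
# --- equivalent to torch.repeat_interleave along the head dimension.
import torch

kv = torch.randn(2, 4, 5, 8)                              # (batch, kv_heads, seq, head_dim)
out = repeat_kv(kv, n_rep=2)
print(out.shape)                                          # torch.Size([2, 8, 5, 8])
print(torch.equal(out, kv.repeat_interleave(2, dim=1)))   # True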
|
||||
|
||||
|
||||
class DreamAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
and "Generating Long Sequences with Sparse Transformers".
|
||||
"""
|
||||
|
||||
def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = False
|
||||
self.attention_dropout = config.attention_dropout
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.rotary_emb = DreamRotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class DreamSdpaAttention(DreamAttention):
|
||||
"""
|
||||
Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from DreamAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
# causal_mask = attention_mask
|
||||
# if attention_mask is not None: # no matter the length, we just slice it
|
||||
# causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
# is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=False, # always False: Dream uses full bidirectional (non-causal) attention
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class DreamDecoderLayer(nn.Module):
|
||||
def __init__(self, config: DreamConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_once(
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
|
||||
# self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
self.self_attn = DreamSdpaAttention(config, layer_idx)
|
||||
|
||||
self.mlp = DreamMLP(config)
|
||||
self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
class DreamPreTrainedModel(PreTrainedModel):
|
||||
config_class = DreamConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DreamDecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
*model_args,
|
||||
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
force_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
revision: str = "main",
|
||||
use_safetensors: Optional[bool] = None,
|
||||
weights_only: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
_model = super().from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
*model_args,
|
||||
config=config,
|
||||
cache_dir=cache_dir,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
use_safetensors=use_safetensors,
|
||||
weights_only=weights_only,
|
||||
**kwargs,
|
||||
)
|
||||
# NOTE(Lin): we need to override the generation config
|
||||
# because the generation config loaded in `from_pretrained`
|
||||
# does not include all the attributes of DreamGenerationConfig
|
||||
resume_download = kwargs.get("resume_download", None)
|
||||
proxies = kwargs.get("proxies", None)
|
||||
subfolder = kwargs.get("subfolder", "")
|
||||
from_auto_class = kwargs.get("_from_auto", False)
|
||||
from_pipeline = kwargs.get("_from_pipeline", None)
|
||||
_model.generation_config = DreamGenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
)
|
||||
return _model
|
||||
|
||||
class DreamBaseModel(DreamPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: DreamConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: DreamConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = DreamRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = DreamBaseModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def reset_rope_parameters(self):
|
||||
self.model.rotary_emb.reset_parameters()
|
||||
for layer in self.model.layers:
|
||||
layer.self_attn.rotary_emb.reset_parameters()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, MaskedLMOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if attention_mask is None or (isinstance(attention_mask, str) and attention_mask == "full"):
|
||||
# full (no-padding) attention: nothing to convert
|
||||
pass
|
||||
|
||||
elif isinstance(attention_mask, torch.Tensor):
|
||||
if not torch.any(attention_mask == 0.0):
|
||||
attention_mask = 'full'
|
||||
elif attention_mask.dim() == 2:
|
||||
# [B, L] → [B, 1, L, L]
|
||||
attention_mask = torch.logical_and(
|
||||
attention_mask.unsqueeze(1).unsqueeze(-2),
|
||||
attention_mask.unsqueeze(1).unsqueeze(-1),
|
||||
)
|
||||
attention_mask = attention_mask.to(torch.bool)
|
||||
|
||||
elif attention_mask.dim() in (3, 4):
|
||||
# already extended/broadcasted form
|
||||
if attention_mask.dtype != torch.bool:
|
||||
attention_mask = attention_mask.to(torch.bool)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected attention_mask shape: {attention_mask.shape}")
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unsupported attention_mask type: {type(attention_mask)}")
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
346
dllm/dllm/pipelines/dream/models/tokenization_dream.py
Normal file
@ -0,0 +1,346 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Dream team, HKUNLP Group and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on Qwen's implementations in this library.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for Dream."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import unicodedata
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import regex as re
|
||||
|
||||
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"merges_file": "merges.txt",
|
||||
}
|
||||
|
||||
|
||||
MAX_MODEL_INPUT_SIZES = {"dream/dream-tokenizer": 32768}
|
||||
|
||||
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||
|
||||
|
||||
@lru_cache()
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
|
||||
characters the bpe code barfs on.
|
||||
|
||||
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
||||
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
|
||||
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
|
||||
tables between utf-8 bytes and unicode strings.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
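# Example (illustrative, not part of the original source): the space byte 0x20 is
# not printable, so it is remapped to chr(256 + 32) == "Ġ", which is why BPE
# merges show tokens like "Ġhello":
#   bytes_to_unicode()[ord(" ")]  # -> 'Ġ'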
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
|
||||
def get_pairs(word):
|
||||
"""
|
||||
Return set of symbol pairs in a word.
|
||||
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
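# Example (illustrative):
#   get_pairs(("h", "e", "l", "l", "o"))
#   -> {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")}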
|
||||
|
||||
|
||||
class DreamTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a Dream tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
|
||||
As with GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
|
||||
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Base-7B", trust_remote_code=True)
|
||||
>>> tokenizer("Hello world")["input_ids"]
|
||||
[9707, 1879]
|
||||
|
||||
>>> tokenizer(" Hello world")["input_ids"]
|
||||
[21927, 1879]
|
||||
```
|
||||
This is expected.
|
||||
|
||||
You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||
this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`):
|
||||
Path to the merges file.
|
||||
errors (`str`, *optional*, defaults to `"replace"`):
|
||||
Paradigm to follow when decoding bytes to UTF-8. See
|
||||
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
||||
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
bos_token (`str`, *optional*):
|
||||
The beginning of sequence token. Not applicable for this tokenizer.
|
||||
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should clean up the spaces that were added when splitting the input text during the
|
||||
tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
|
||||
split_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the special tokens should be split during the tokenization process. The default behavior is
|
||||
to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` =
|
||||
`['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
|
||||
'|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
errors="replace",
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token=None,
|
||||
eos_token="<|endoftext|>",
|
||||
pad_token="<|endoftext|>",
|
||||
clean_up_tokenization_spaces=False,
|
||||
split_special_tokens=False,
|
||||
**kwargs,
|
||||
):
|
||||
# Dream vocab does not contain control tokens; added tokens need to be special
|
||||
bos_token = (
|
||||
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(bos_token, str)
|
||||
else bos_token
|
||||
)
|
||||
eos_token = (
|
||||
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(eos_token, str)
|
||||
else eos_token
|
||||
)
|
||||
unk_token = (
|
||||
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(unk_token, str)
|
||||
else unk_token
|
||||
)
|
||||
pad_token = (
|
||||
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(pad_token, str)
|
||||
else pad_token
|
||||
)
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
bpe_merges = []
|
||||
with open(merges_file, encoding="utf-8") as merges_handle:
|
||||
for i, line in enumerate(merges_handle):
|
||||
line = line.strip()
|
||||
if (i == 0 and line.startswith("#version:")) or not line:
|
||||
continue
|
||||
bpe_merges.append(tuple(line.split()))
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
# NOTE: the cache can grow without bound and will get really large for long running processes
|
||||
# (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically
|
||||
# not a memory leak but appears as one.
|
||||
# GPT2Tokenizer has the same problem, so let's be consistent.
|
||||
self.cache = {}
|
||||
|
||||
self.pat = re.compile(PRETOKENIZE_REGEX)
|
||||
|
||||
if kwargs.get("add_prefix_space", False):
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
errors=errors,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
unk_token=unk_token,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
split_special_tokens=split_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self.encoder)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
|
||||
def get_vocab(self):
|
||||
return dict(self.encoder, **self.added_tokens_encoder)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
except ValueError:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
else:
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
|
||||
def _tokenize(self, text):
|
||||
"""Tokenize a string."""
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = "".join(
|
||||
self.byte_encoder[b] for b in token.encode("utf-8")
|
||||
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.decoder.get(index)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
text = "".join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
||||
return text
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = False,
|
||||
spaces_between_special_tokens: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
# `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
|
||||
# and cannot be configured elsewhere, but it should default to False for DreamTokenizer
|
||||
return super().decode(
|
||||
token_ids,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
merge_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
||||
)
|
||||
|
||||
with open(vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write("#version: 0.2\n")
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!"
|
||||
)
|
||||
index = token_index
|
||||
writer.write(" ".join(bpe_tokens) + "\n")
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file
|
||||
|
||||
def prepare_for_tokenization(self, text, **kwargs):
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
return (text, kwargs)
|
||||
|
||||
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
|
||||
from .configuration_dream import DreamConfig
|
||||
|
||||
TOKENIZER_MAPPING.register(DreamConfig, (DreamTokenizer, None))
|
||||
84
dllm/dllm/pipelines/dream/trainer.py
Normal file
@ -0,0 +1,84 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from dllm.core.trainers import MDLMTrainer
|
||||
|
||||
|
||||
def cart_weight(
|
||||
masked_indices: torch.Tensor, t: torch.Tensor, p: float = 0.3
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Optimized CART weight computation using matrix operations.
|
||||
|
||||
Args:
|
||||
masked_indices (torch.Tensor): (b, l) bool tensor indicating masked positions.
|
||||
t (torch.Tensor): (b,) time steps (0-1 sampled uniformly). Not directly used in CART.
|
||||
p (float): Parameter of geometric distribution (0 < p <= 1).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: (b, l) float tensor of weights.
|
||||
"""
|
||||
b, l = masked_indices.shape
|
||||
device = masked_indices.device
|
||||
|
||||
idx = torch.arange(l, device=device)
|
||||
dist_matrix = (idx[None, :] - idx[:, None]).abs() - 1
|
||||
dist_matrix = torch.clamp(dist_matrix, min=0) # (l, l)
|
||||
geo_matrix = (
|
||||
torch.log(torch.tensor(p, device=device))
|
||||
+ (dist_matrix - 1).clamp(min=0) * torch.log(torch.tensor(1 - p, device=device))
|
||||
).exp() * 0.5 # geometric weights, computed in log space for numerical stability
|
||||
geo_matrix.masked_fill_(dist_matrix == 0, 0.0) # ignore distance = 0
|
||||
|
||||
valid_mask = (~masked_indices).float() # (b, l), 1 = unmasked
|
||||
weights = valid_mask @ geo_matrix.T # (b, l)
|
||||
weights = weights * masked_indices.float()
|
||||
return weights
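# Illustrative example (hand-computed from the code above, not part of the
# original source): with
#   masked_indices = torch.tensor([[False, True, True, False]]) and p = 0.5,
# the two unmasked endpoints contribute geometric weight to the masked middle:
#   cart_weight(masked_indices, t=torch.zeros(1), p=0.5)
#   -> tensor([[0.0000, 0.2500, 0.2500, 0.0000]])
# (0.25 = 0.5 * p comes from the unmasked token two positions away; directly
# adjacent unmasked tokens contribute nothing because dist_matrix == 0 there.)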
|
||||
|
||||
|
||||
class DreamTrainer(MDLMTrainer):
|
||||
"""
|
||||
DreamTrainer: specialization of MDLMTrainer for Dream training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
loss_weight_type: str = "cart[geo_p:0.3]",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, loss_weight_type=loss_weight_type, **kwargs)
|
||||
|
||||
def _preprocess_inputs(self, inputs):
|
||||
labels = inputs["labels"]
|
||||
assert (labels[:, 0] == -100).all()
|
||||
|
||||
def _postprocess_outputs(self, outputs):
|
||||
logits = outputs.logits
|
||||
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
|
||||
def _compute_loss_weights(
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
inputs: dict[str, Any],
|
||||
masked_indices: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if self.loss_weight_type.startswith("cart"):
|
||||
# parse geo_p
|
||||
import re
|
||||
|
||||
match = re.search(r"geo_p:(0\.\d+)", self.loss_weight_type)
|
||||
geo_p = float(match.group(1)) if match else 0.3
|
||||
loss_weights = cart_weight(masked_indices, t, p=geo_p)
|
||||
else:
|
||||
loss_weights = super()._compute_loss_weights(
|
||||
t=t,
|
||||
inputs=inputs,
|
||||
masked_indices=masked_indices,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
return loss_weights
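# Example (illustrative): loss_weight_type="cart[geo_p:0.5]" parses geo_p=0.5;
# a loss_weight_type that does not start with "cart" falls back to
# MDLMTrainer's default weighting.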
|
||||
180
dllm/dllm/pipelines/dream/utils.py
Normal file
@ -0,0 +1,180 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
|
||||
|
||||
def top_p_logits(logits, top_p=None):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
||||
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
||||
return logits
|
||||
|
||||
|
||||
def top_k_logits(logits, top_k=None):
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
||||
return logits
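# Typical usage (illustrative sketch, not from the original source):
#   logits = logits / temperature
#   logits = top_p_logits(logits, top_p=0.95)
#   logits = top_k_logits(logits, top_k=50)
#   probs = torch.softmax(logits, dim=-1)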
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamSFTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Randomly crop response length to reduce length bias during generation.
|
||||
|
||||
Reference: https://github.com/DreamLM/Dream/blob/main/src/trainer/fsdp_sft_trainer.py
|
||||
"""
|
||||
|
||||
perbatch_cutoff: bool = True # Use per-batch, pre-collation truncation if True
|
||||
resp_cutoff_ratio: float = 0.0 # Prob. of post-collation truncation
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 1) Pre-collation truncation (per-sample)
|
||||
# -------------------------------------------------------------------------
|
||||
def apply_perbatch_cutoff(self, features):
|
||||
"""
|
||||
Randomly pick a response length from batch (`kept_len`) and trim other responses.
|
||||
Before:
|
||||
[<--promptA----><------responseA------>]
|
||||
[<--promptB-><---responseB--->]
|
||||
[<---promptC----><--respC-->]
|
||||
After:
|
||||
[<--promptA----><---respA--->]
|
||||
[<--promptB-><--respB-->]
|
||||
[<---promptC----><--respC-->]
|
||||
kept_len = 10 → trim each response to ≤10 tokens (before padding)
|
||||
"""
|
||||
resp_lens = torch.tensor(
|
||||
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
|
||||
)
|
||||
kept_len = int(np.random.choice(resp_lens))
|
||||
for f, r_len in zip(features, resp_lens):
|
||||
remove_len = max(r_len - kept_len, 0)
|
||||
if remove_len > 0:
|
||||
# f["input_ids"] = f["input_ids"][:-remove_len]
|
||||
# f["attention_mask"] = f["attention_mask"][:-remove_len]
|
||||
# f["labels"] = f["labels"][:-remove_len]
|
||||
for key in ["input_ids", "labels", "attention_mask"]:
|
||||
if key in f:
|
||||
f[key] = f[key][:-remove_len]
|
||||
return features
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 2) Post-collation truncation
|
||||
# -------------------------------------------------------------------------
|
||||
def apply_resp_cutoff(self, batch, features):
|
||||
"""
|
||||
Uniformly chop tail *after padding*. All sequences truncated to new_seq_len.
|
||||
Before:
|
||||
[<--promptA----><-----respA----->] 40
|
||||
[<--promptB-><respB><----pad---->] 40
|
||||
[<---promptC----><--respC--><pad>] 40
|
||||
cutoff_len = 5
|
||||
After:
|
||||
[<--promptA----><--respA--->] 35
|
||||
[<--promptB-><respB><--pad->] 35
|
||||
[<---promptC----><--respC-->] 35
|
||||
"""
|
||||
orig_seq_lens = [len(f["input_ids"]) for f in features]
|
||||
resp_lens = torch.tensor(
|
||||
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
|
||||
)
|
||||
min_resp_len = resp_lens.min().item()
|
||||
if min_resp_len <= 1:
|
||||
return batch
|
||||
|
||||
cutoff_len = int(np.random.randint(1, min_resp_len))
|
||||
new_seq_len = max(orig_seq_lens) - cutoff_len
|
||||
|
||||
for key in ["input_ids", "labels", "attention_mask"]:
|
||||
if key in batch:
|
||||
batch[key] = batch[key][:, :new_seq_len].contiguous()
|
||||
return batch
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 3) Main call: pick truncation mode
|
||||
# -------------------------------------------------------------------------
|
||||
def __call__(self, features, return_tensors=None):
|
||||
# optional pre-collation truncation
|
||||
if self.perbatch_cutoff:
|
||||
features = self.apply_perbatch_cutoff(features)
|
||||
|
||||
# always collate only the needed fields
|
||||
base = [
|
||||
{k: f[k] for k in ("input_ids", "labels", "attention_mask") if k in f}
|
||||
for f in features
|
||||
]
|
||||
batch = super().__call__(base, return_tensors=return_tensors)
|
||||
|
||||
# optional post-collation truncation
|
||||
if (
|
||||
not self.perbatch_cutoff
|
||||
and self.resp_cutoff_ratio > 0
|
||||
and np.random.rand() < self.resp_cutoff_ratio
|
||||
):
|
||||
batch = self.apply_resp_cutoff(batch, features)
|
||||
|
||||
batch.pop("prompt_len", None)
|
||||
return batch
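# Hypothetical usage sketch (names and values are illustrative, not from the source):
#   collator = DreamSFTCollator(tokenizer=tokenizer, perbatch_cutoff=True)
#   batch = collator([
#       {"input_ids": ids, "labels": labels, "attention_mask": mask, "prompt_len": 4},
#       ...
#   ])
# Each feature must carry "prompt_len" so the response portion can be cropped.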
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamPTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
random_length_ratio: float = 0.01
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors=return_tensors)
|
||||
input_ids, labels, attention_mask = (
|
||||
outputs["input_ids"],
|
||||
outputs["labels"],
|
||||
outputs["attention_mask"],
|
||||
)
|
||||
bsz, seq_len = input_ids.shape
|
||||
|
||||
# --- Random truncation for robustness ---
|
||||
if torch.rand(1).item() < self.random_length_ratio:
|
||||
random_len = torch.randint(1, seq_len + 1, (1,)).item()
|
||||
input_ids = input_ids[:, :random_len]
|
||||
labels = labels[:, :random_len]
|
||||
attention_mask = attention_mask[:, :random_len]
|
||||
|
||||
# --- Add BOS token to the beginning of input_ids ---
|
||||
bos = torch.full(
|
||||
(bsz, 1),
|
||||
self.tokenizer.bos_token_id,
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
input_ids = torch.cat([bos, input_ids], dim=1)
|
||||
|
||||
# --- Prepend zeros to labels instead of BOS ---
|
||||
ignore_labels = self.label_pad_token_id * torch.ones(
|
||||
(bsz, 1), dtype=labels.dtype, device=labels.device
|
||||
)
|
||||
labels = torch.cat([ignore_labels, labels], dim=1)
|
||||
|
||||
# --- Prepend ones to attention_mask ---
|
||||
bos_attention = torch.ones(
|
||||
(bsz, 1), dtype=attention_mask.dtype, device=attention_mask.device
|
||||
)
|
||||
attention_mask = torch.cat([bos_attention, attention_mask], dim=1)
|
||||
|
||||
# --- Update and return ---
|
||||
outputs["input_ids"] = input_ids
|
||||
outputs["labels"] = labels
|
||||
outputs["attention_mask"] = attention_mask
|
||||
# Check if attention_mask is all ones and set it to None
|
||||
if torch.all(outputs["attention_mask"] == 1):
|
||||
outputs.pop("attention_mask")
|
||||
return outputs
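# Note (illustrative): after collation a batch of sequence length L becomes
# length L + 1, because a BOS token is prepended to input_ids while
# self.label_pad_token_id (-100 by default) and attention value 1 are prepended
# to labels and attention_mask respectively.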
|
||||
14
dllm/dllm/pipelines/editflow/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
from . import trainer, utils
|
||||
from .models.dream.modelling_dream import (
|
||||
EditFlowDreamConfig,
|
||||
EditFlowDreamModel,
|
||||
)
|
||||
from .models.llada.modelling_llada import (
|
||||
EditFlowLLaDAConfig,
|
||||
EditFlowLLaDAModel,
|
||||
)
|
||||
from .models.bert.modelling_modernbert import (
|
||||
EditFlowModernBertConfig,
|
||||
EditFlowModernBertModel,
|
||||
)
|
||||
from dllm.pipelines.editflow.trainer import EditFlowTrainer
|
||||
89
dllm/dllm/pipelines/editflow/models/bert/modelling_modernbert.py
Normal file
@ -0,0 +1,89 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
class EditFlowModernBertConfig(transformers.ModernBertConfig):
|
||||
model_type = "editflow-modernbert" # <- NEW model_type
|
||||
|
||||
|
||||
class EditFlowModernBertModel(transformers.ModernBertForMaskedLM):
|
||||
config_class = EditFlowModernBertConfig
|
||||
modules_to_save = {
|
||||
"rate_heads",
|
||||
"sub_logits",
|
||||
"ins_logits",
|
||||
} # fully finetuned even when using LoRA
|
||||
|
||||
def __init__(self, config):
|
||||
# fa2 has bugs when forward(output_hidden_states=True)
|
||||
config._attn_implementation = "sdpa"
|
||||
super().__init__(config)
|
||||
in_lm, out_lm = self.decoder.in_features, self.decoder.out_features
|
||||
use_bias = self.decoder.bias is not None
|
||||
# Create new, independent heads (no deepcopy)
|
||||
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.rate_heads = nn.Sequential(nn.Linear(in_lm, 3), nn.Softplus())
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
t: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
h = output["hidden_states"][-1] # final hidden states
|
||||
h = self.head(h)
|
||||
# Position heads
|
||||
sub_log = self.sub_logits(h) # [B, L, V]
|
||||
ins_log = self.ins_logits(h) # [B, L, V]
|
||||
|
||||
rates = self.rate_heads(h)
|
||||
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
|
||||
-1
|
||||
) # [B, L], [B, L], [B, L]
|
||||
return dict(
|
||||
sub_rate_hat=sub_rate_hat, # [B,L]
|
||||
del_rate_hat=del_rate_hat, # [B,L]
|
||||
ins_rate_hat=ins_rate_hat, # [B,L]
|
||||
ins_logits=ins_log, # [B,L,V]
|
||||
sub_logits=sub_log, # [B,L,V]
|
||||
)
|
||||
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoConfig
|
||||
|
||||
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
||||
AutoConfig.register("editflow-modernbert", EditFlowModernBertConfig)
|
||||
AutoModel.register(EditFlowModernBertConfig, EditFlowModernBertModel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dllm
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
# Load a config from a local path (either a directory containing config.json, or the file itself)
|
||||
config_path = dllm.utils.resolve_with_base_env(
|
||||
"answerdotai/ModernBERT-base", "BASE_MODELS_DIR"
|
||||
)
|
||||
config = EditFlowModernBertConfig.from_pretrained(config_path)
|
||||
if hasattr(config, "auto_map"):
|
||||
delattr(config, "auto_map")
|
||||
if hasattr(config, "architectures"):
|
||||
delattr(config, "architectures")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
model = EditFlowModernBertModel(config)
|
||||
model.save_pretrained("models-tmp/editflow-modernbert")
|
||||
auto_model = AutoModel.from_pretrained("models-tmp/editflow-modernbert")
|
||||
97
dllm/dllm/pipelines/editflow/models/dream/modelling_dream.py
Normal file
@ -0,0 +1,97 @@
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dllm.pipelines import dream
|
||||
|
||||
|
||||
class EditFlowDreamConfig(dream.DreamConfig):
|
||||
model_type = "editflow-dream" # <- NEW model_type
|
||||
|
||||
|
||||
class EditFlowDreamModel(dream.DreamModel):
|
||||
config_class = EditFlowDreamConfig
|
||||
modules_to_save = {
|
||||
"rate_heads",
|
||||
"sub_logits",
|
||||
"ins_logits",
|
||||
} # fully finetuned even when using LoRA
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
in_lm, out_lm = self.lm_head.in_features, self.lm_head.out_features
|
||||
use_bias = self.lm_head.bias is not None
|
||||
# Create new, independent heads (no deepcopy)
|
||||
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
t: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
h = output["hidden_states"][-1] # final hidden states
|
||||
# Position heads
|
||||
sub_log = self.sub_logits(h) # [B, L, V]
|
||||
sub_log = torch.concatenate(
|
||||
[torch.zeros_like(sub_log)[:, :1], sub_log[:, :-1]], dim=1
|
||||
) # [B, L, V]
|
||||
ins_log = self.ins_logits(h) # [B, L, V]
|
||||
|
||||
rates = self.rate_heads(h)
|
||||
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
|
||||
-1
|
||||
) # [B, L], [B, L], [B, L]
|
||||
sub_rate_hat = torch.concatenate(
|
||||
[torch.zeros_like(sub_rate_hat[:, :1]), sub_rate_hat[:, :-1]], dim=1
|
||||
) # [B, L]
|
||||
del_rate_hat = torch.concatenate(
|
||||
[torch.zeros_like(del_rate_hat[:, :1]), del_rate_hat[:, :-1]], dim=1
|
||||
) # [B, L]
|
||||
return dict(
|
||||
sub_rate_hat=sub_rate_hat, # [B,L]
|
||||
del_rate_hat=del_rate_hat, # [B,L]
|
||||
ins_rate_hat=ins_rate_hat, # [B,L]
|
||||
ins_logits=ins_log, # [B,L,V]
|
||||
sub_logits=sub_log, # [B,L,V]
|
||||
)
|
||||
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoConfig
|
||||
|
||||
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
||||
AutoConfig.register("editflow-dream", EditFlowDreamConfig)
|
||||
AutoModel.register(EditFlowDreamConfig, EditFlowDreamModel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dllm
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
# Load a config from a local path (either a directory containing config.json, or the file itself)
|
||||
config_path = dllm.utils.resolve_with_base_env(
|
||||
"Dream-org/Dream-v0-Base-7B", "BASE_MODELS_DIR"
|
||||
)
|
||||
config = EditFlowDreamConfig.from_pretrained(config_path)
|
||||
if hasattr(config, "auto_map"):
|
||||
delattr(config, "auto_map")
|
||||
if hasattr(config, "architectures"):
|
||||
delattr(config, "architectures")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
model = EditFlowDreamModel(config)
|
||||
model.save_pretrained("models-tmp/editflow-dream")
|
||||
auto_model = AutoModel.from_pretrained("models-tmp/editflow-dream")
|
||||
91
dllm/dllm/pipelines/editflow/models/llada/modelling_llada.py
Normal file
@ -0,0 +1,91 @@
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dllm.pipelines import llada
|
||||
|
||||
|
||||
class EditFlowLLaDAConfig(llada.LLaDAConfig):
|
||||
model_type = "editflow-llada" # <- NEW model_type
|
||||
|
||||
|
||||
class EditFlowLLaDAModel(llada.LLaDAModelLM):
|
||||
config_class = EditFlowLLaDAConfig
|
||||
modules_to_save = {
|
||||
"rate_heads",
|
||||
"sub_logits",
|
||||
"ins_logits",
|
||||
} # fully finetuned even when using LoRA
|
||||
|
||||
def __init__(self, config):
|
||||
# TODO: time embedding
|
||||
super().__init__(config)
|
||||
ff = self.model.transformer.ff_out
|
||||
in_f, out_f = ff.in_features, ff.out_features
|
||||
use_bias = ff.bias is not None
|
||||
# Create new, independent heads (no deepcopy)
|
||||
self.sub_logits = nn.Linear(in_f, out_f, bias=use_bias)
|
||||
self.ins_logits = nn.Linear(in_f, out_f, bias=use_bias)
|
||||
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
t: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# TODO: time embedding
|
||||
output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
h = output["hidden_states"][-1] # final hidden states
|
||||
# Position heads
|
||||
sub_log = self.sub_logits(h) # [B, L, V]
|
||||
ins_log = self.ins_logits(h) # [B, L, V]
|
||||
|
||||
rates = self.rate_heads(h)
|
||||
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
|
||||
-1
|
||||
) # [B, L], [B, L], [B, L]
|
||||
return dict(
|
||||
sub_rate_hat=sub_rate_hat, # [B,L]
|
||||
del_rate_hat=del_rate_hat, # [B,L]
|
||||
ins_rate_hat=ins_rate_hat, # [B,L]
|
||||
ins_logits=ins_log, # [B,L,V]
|
||||
sub_logits=sub_log, # [B,L,V]
|
||||
)
|
||||
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoConfig
|
||||
|
||||
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
||||
AutoConfig.register("editflow-llada", EditFlowLLaDAConfig)
|
||||
AutoModel.register(EditFlowLLaDAConfig, EditFlowLLaDAModel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dllm
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
# Load a config from a local path (either a directory containing config.json, or the file itself)
|
||||
config_path = dllm.utils.resolve_with_base_env(
|
||||
"GSAI-ML/LLaDA-8B-Base", "BASE_MODELS_DIR"
|
||||
)
|
||||
config = EditFlowLLaDAConfig.from_pretrained(config_path)
|
||||
if hasattr(config, "auto_map"):
|
||||
delattr(config, "auto_map")
|
||||
if hasattr(config, "architectures"):
|
||||
delattr(config, "architectures")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
model = EditFlowLLaDAModel(config)
|
||||
model.save_pretrained("models-tmp/editflow-llada")
|
||||
auto_model = AutoModel.from_pretrained("models-tmp/editflow-llada")
|
||||
407
dllm/dllm/pipelines/editflow/trainer.py
Normal file
@ -0,0 +1,407 @@
|
||||
from typing import Any, Dict, Union, List, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import transformers
|
||||
|
||||
from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler
|
||||
from dllm.pipelines.editflow.utils import pad_1d
|
||||
|
||||
|
||||
BLANK = -1
|
||||
|
||||
|
||||
def align_with_blanks(
|
||||
x0: List[int], x1: List[int], sub_cost: int = 1, gap_cost: int = 1
|
||||
) -> Dict:
|
||||
"""
|
||||
Needleman–Wunsch global alignment of two integer sequences with:
|
||||
match cost = 0, substitution cost = sub_cost, gap cost = gap_cost.
|
||||
Returns aligned sequences (z0, z1) of equal length containing BLANK = ε where gaps occur.
|
||||
"""
|
||||
n, m = len(x0), len(x1)
|
||||
# DP tables
|
||||
dp = [[0] * (m + 1) for _ in range(n + 1)]
|
||||
ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag', 'up', 'left'
|
||||
|
||||
for i in range(1, n + 1):
|
||||
dp[i][0] = i * gap_cost
|
||||
ptr[i][0] = "up"
|
||||
for j in range(1, m + 1):
|
||||
dp[0][j] = j * gap_cost
|
||||
ptr[0][j] = "left"
|
||||
|
||||
for i in range(1, n + 1):
|
||||
for j in range(1, m + 1):
|
||||
cost_diag = dp[i - 1][j - 1] + (0 if x0[i - 1] == x1[j - 1] else sub_cost)
|
||||
cost_up = dp[i - 1][j] + gap_cost
|
||||
cost_left = dp[i][j - 1] + gap_cost
|
||||
best = min(cost_diag, cost_up, cost_left)
|
||||
dp[i][j] = best
|
||||
if best == cost_diag:
|
||||
ptr[i][j] = "diag"
|
||||
elif best == cost_up:
|
||||
ptr[i][j] = "up"
|
||||
else:
|
||||
ptr[i][j] = "left"
|
||||
|
||||
# traceback
|
||||
z0, z1 = [], []
|
||||
i, j = n, m
|
||||
while i > 0 or j > 0:
|
||||
p = ptr[i][j]
|
||||
if p == "diag":
|
||||
z0.append(x0[i - 1])
|
||||
z1.append(x1[j - 1])
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif p == "up":
|
||||
z0.append(x0[i - 1])
|
||||
z1.append(BLANK)
|
||||
i -= 1
|
||||
else: # 'left'
|
||||
z0.append(BLANK)
|
||||
z1.append(x1[j - 1])
|
||||
j -= 1
|
||||
z0.reverse()
|
||||
z1.reverse()
|
||||
# return Alignment(z0=z0, z1=z1)
|
||||
# return {"z0": z0, "z1": z1}
|
||||
return dict(z0=z0, z1=z1)
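# Worked example (hand-computed, illustrative): aligning x0 = [5, 6] with
# x1 = [5, 7, 6] inserts one blank into x0 opposite the extra token 7:
#   align_with_blanks([5, 6], [5, 7, 6])
#   -> {"z0": [5, -1, 6], "z1": [5, 7, 6]}   # -1 is BLANK (ε)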
|
||||
|
||||
|
||||
# def align_with_blanks(
|
||||
# x0: list[int], x1: list[int], sub_cost: int = 1, gap_cost: int = 1
|
||||
# ) -> dict:
|
||||
# """
|
||||
# Needleman–Wunsch with a secondary objective that defers gaps to the end:
|
||||
# - 'up' (gap in z1) is penalized if j < m
|
||||
# - 'left' (gap in z0) is penalized if i < n
|
||||
# This pushes blanks (-1) to the *right* whether x0 > x1 or x0 < x1.
|
||||
# """
|
||||
# n, m = len(x0), len(x1)
|
||||
|
||||
# dp_cost = [[0] * (m + 1) for _ in range(n + 1)]
|
||||
# dp_pen = [[0] * (m + 1) for _ in range(n + 1)]
|
||||
# ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag' | 'up' | 'left'
|
||||
|
||||
# # Left edge: all 'up' moves with j=0 (< m) → penalize each step
|
||||
# for i in range(1, n + 1):
|
||||
# dp_cost[i][0] = i * gap_cost
|
||||
# dp_pen[i][0] = i # i early 'up' moves
|
||||
# ptr[i][0] = "up"
|
||||
|
||||
# # Top edge: all 'left' moves with i=0 (< n) → penalize each step
|
||||
# for j in range(1, m + 1):
|
||||
# dp_cost[0][j] = j * gap_cost
|
||||
# dp_pen[0][j] = j # j early 'left' moves
|
||||
# ptr[0][j] = "left"
|
||||
|
||||
# for i in range(1, n + 1):
|
||||
# xi = x0[i - 1]
|
||||
# for j in range(1, m + 1):
|
||||
# yj = x1[j - 1]
|
||||
|
||||
# # diag
|
||||
# cost_diag = dp_cost[i - 1][j - 1] + (0 if xi == yj else sub_cost)
|
||||
# pen_diag = dp_pen[i - 1][j - 1]
|
||||
# cand_diag = (cost_diag, pen_diag)
|
||||
|
||||
# # up: add blank to z1, penalize if j < m (early)
|
||||
# cost_up = dp_cost[i - 1][j] + gap_cost
|
||||
# pen_up = dp_pen[i - 1][j] + (1 if j < m else 0)
|
||||
# cand_up = (cost_up, pen_up)
|
||||
|
||||
# # left: add blank to z0, penalize if i < n (early)
|
||||
# cost_left = dp_cost[i][j - 1] + gap_cost
|
||||
# pen_left = dp_pen[i][j - 1] + (1 if i < n else 0)
|
||||
# cand_left = (cost_left, pen_left)
|
||||
|
||||
# # choose (cost,pen) min; deterministic tie-break: diag > left > up
|
||||
# best = min(cand_diag, cand_left, cand_up)
|
||||
# dp_cost[i][j], dp_pen[i][j] = best
|
||||
# if best == cand_diag:
|
||||
# ptr[i][j] = "diag"
|
||||
# elif best == cand_left:
|
||||
# ptr[i][j] = "left"
|
||||
# else:
|
||||
# ptr[i][j] = "up"
|
||||
|
||||
# # traceback
|
||||
# z0, z1 = [], []
|
||||
# i, j = n, m
|
||||
# while i > 0 or j > 0:
|
||||
# p = ptr[i][j]
|
||||
# if p == "diag":
|
||||
# z0.append(x0[i - 1])
|
||||
# z1.append(x1[j - 1])
|
||||
# i -= 1
|
||||
# j -= 1
|
||||
# elif p == "up":
|
||||
# z0.append(x0[i - 1])
|
||||
# z1.append(BLANK)
|
||||
# i -= 1
|
||||
# else: # 'left'
|
||||
# z0.append(BLANK)
|
||||
# z1.append(x1[j - 1])
|
||||
# j -= 1
|
||||
|
||||
# z0.reverse()
|
||||
# z1.reverse()
|
||||
# return dict(z0=z0, z1=z1)
|
||||
|
||||
|
||||
def strip_blanks(z: list[int]) -> list[int]:
|
||||
# IMPORTANT: do NOT strip BOS; we only remove BLANKs
|
||||
return [t for t in z if t != BLANK]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Edit:
|
||||
kind: str # "SUB" | "DEL" | "INS"
|
||||
pos: int # position (for SUB/DEL) or token-row idx for INS (incl. BOS row 0)
|
||||
token: int | None # token for SUB/INS, else None
|
||||
|
||||
|
||||
def build_remaining_edits(zt: list[int], z1: list[int]) -> list[Edit]:
|
||||
edits: list[Edit] = []
|
||||
|
||||
def count_nonblank_prefix(z: list[int], j: int) -> int:
|
||||
c = 0
|
||||
for k in range(j):
|
||||
if z[k] != BLANK:
|
||||
c += 1
|
||||
return c
|
||||
|
||||
for j, (a, b) in enumerate(zip(zt, z1)):
|
||||
if a == b:
|
||||
continue
|
||||
nb = count_nonblank_prefix(
|
||||
zt, j
|
||||
) # counts BOS as 1, first content token will be nb=1 before its column
|
||||
|
||||
if a == BLANK and b != BLANK:
|
||||
# INSERT after row (nb-1): BOS insert => nb=1 -> gap=0; general case works too
|
||||
gap = max(nb - 1, 0)
|
||||
edits.append(Edit("INS", gap, b))
|
||||
|
||||
elif a != BLANK and b == BLANK:
|
||||
# DELETE token at row nb (first content token => nb=1, allowed; BOS is never BLANK so nb>=1)
|
||||
pos = nb
|
||||
# if pos > 0: # forbid BOS (row 0)
|
||||
edits.append(Edit("DEL", pos, None))
|
||||
|
||||
else: # a != BLANK, b != BLANK, a != b
|
||||
# SUB token at row nb
|
||||
pos = nb
|
||||
# if pos > 0: # forbid BOS (row 0)
|
||||
edits.append(Edit("SUB", pos, b))
|
||||
return edits
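
# Illustrative example (not part of the original file). With BOS = 1 and BLANK = -1,
# take the aligned pair
#   zt = [1, -1, 8, 9]
#   z1 = [1,  5, 6, -1]
# so that x_t = strip_blanks(zt) = [1, 8, 9]. build_remaining_edits(zt, z1) yields
#   [Edit("INS", 0, 5),     # insert token 5 after row 0 (the BOS row)
#    Edit("SUB", 1, 6),     # substitute row 1 (token 8) with token 6
#    Edit("DEL", 2, None)]  # delete row 2 (token 9)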
|
||||
|
||||
|
||||
class EditFlowTrainer(transformers.Trainer):
|
||||
"""
|
||||
Trainer for Edit Flows where the model returns:
|
||||
- sub_logits: [B,L,V] (token dist for SUB)
|
||||
- ins_logits: [B,L,V] (token dist for INS)
|
||||
- sub_rate_hat: [B,L] (normalized rates; NO kappa factor)
|
||||
- del_rate_hat: [B,L]
|
||||
- ins_rate_hat: [B,L]
|
||||
True intensities are w * rate_hat, with w = kappa_dot(t) / (1 - kappa(t)).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
scheduler: BaseKappaScheduler | None = None,
|
||||
normalize_per_position: bool = True,
|
||||
time_epsilon: float = 1e-3,
|
||||
max_w: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.scheduler = scheduler or CubicKappaScheduler()
|
||||
self.normalize_per_position = normalize_per_position
|
||||
self.time_epsilon = time_epsilon
|
||||
self.max_w = max_w
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: transformers.PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
return_outputs: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
device = self.model.device
|
||||
B = len(inputs["x0_ids"])
|
||||
|
||||
# -------- 1) Align with blanks (z0,z1) and sample time t --------
|
||||
aligns = [
|
||||
align_with_blanks(x0, x1)
|
||||
for x0, x1 in zip(inputs["x0_ids"], inputs["x1_ids"])
|
||||
]
|
||||
z0_list = [a["z0"] for a in aligns]
|
||||
z1_list = [a["z1"] for a in aligns]
|
||||
assert all(len(z0) == len(z1) for z0, z1 in zip(z0_list, z1_list))
|
||||
assert all(z0[0] != BLANK for z0 in z0_list) # BOS must remain
|
||||
|
||||
t = (1 - self.time_epsilon) * torch.rand(B, 1, device=device) # [B,1]
|
||||
k = self.scheduler.kappa(t).to(device) # [B,1]
|
||||
w = self.scheduler.weight(t).squeeze(1).to(device) # [B]
|
||||
if self.max_w:
|
||||
w = w.clamp(max=self.max_w)
|
||||
|
||||
# -------- 2) Sample z_t by κ-mixing (vectorized per example) --------
|
||||
# Keep python lists -> tensors per-example to reuse build_remaining_edits
|
||||
zt_list: list[list[int]] = []
|
||||
for z0, z1, kb in zip(z0_list, z1_list, k.squeeze(1).tolist()):
|
||||
# per-column Bernoulli(κ) mix; BOS is equal in z0/z1 so it stays BOS
|
||||
choose_target = torch.rand(len(z0)) < kb
|
||||
zt = [b if choose_target[j] else a for j, (a, b) in enumerate(zip(z0, z1))]
|
||||
zt_list.append(zt)
|
||||
|
||||
# -------- 3) Strip blanks to x_t and compute remaining edits --------
|
||||
xt_list = [strip_blanks(zt) for zt in zt_list]
|
||||
edits_list: list[list[Edit]] = [
|
||||
build_remaining_edits(zt, z1) for zt, z1 in zip(zt_list, z1_list)
|
||||
]
|
||||
|
||||
# -------- 4) Collate x_t for the model --------
|
||||
x_tok, x_mask = pad_1d(
|
||||
xt_list, pad_val=self.processing_class.pad_token_id
|
||||
) # [B,Lmax], [B,Lmax]
|
||||
x_tok = x_tok.to(device)
|
||||
x_mask = x_mask.to(device)
|
||||
|
||||
# -------- 5) Forward pass --------
|
||||
out = model(input_ids=x_tok, attention_mask=x_mask, t=t.to(device))
|
||||
# Rename for clarity: model returns normalized rates (no kappa)
|
||||
sub_rate_hat = out["sub_rate_hat"] # [B,L]
|
||||
del_rate_hat = out["del_rate_hat"] # [B,L]
|
||||
ins_rate_hat = out["ins_rate_hat"] # [B,L]
|
||||
logQ_sub = F.log_softmax(out["sub_logits"], dim=-1) # [B,L,V]
|
||||
logQ_ins = F.log_softmax(out["ins_logits"], dim=-1) # [B,L,V]
|
||||
|
||||
# *** NEW: zero-cost anchor to "touch" every head even if unused this step ***
|
||||
# Using .sum() * 0.0 keeps a graph dependency without changing the loss value.
|
||||
# Include both logits (for SUB/INS heads) and rates (for SUB/DEL/INS heads).
|
||||
# This is important for Deepspeed ZeRO stage 2/3 to avoid skipping unused parameters.
|
||||
anchor = (
|
||||
sub_rate_hat.sum() * 0.0
|
||||
+ del_rate_hat.sum() * 0.0
|
||||
+ ins_rate_hat.sum() * 0.0
|
||||
+ logQ_sub.sum() * 0.0
|
||||
+ logQ_ins.sum() * 0.0
|
||||
)
|
||||
|
||||
# Utility
|
||||
def safe_log(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.log(x.clamp_min(1e-12))
|
||||
|
||||
# -------- 6) Survival term --------
|
||||
# Survival = E[sum of true intensities over valid rows]
|
||||
# true intensity = w[b] * rate_hat[b, i]
|
||||
mask_f = x_mask.float()
|
||||
# L = mask_f.sum(dim=1).clamp_min(1.0) # [B] number of positions (incl. BOS)
|
||||
L1 = torch.tensor(
|
||||
[len(x1) for x1 in inputs["x1_ids"]], device=device, dtype=torch.float
|
||||
).clamp_min(1.0)
|
||||
denom = L1 if self.normalize_per_position else torch.ones_like(L1)
|
||||
|
||||
Lambda_hat = ((sub_rate_hat + del_rate_hat + ins_rate_hat) * mask_f).sum(
|
||||
dim=1
|
||||
) # [B]
|
||||
loss_surv = ((w * Lambda_hat) / denom).mean()
|
||||
|
||||
# -------- 7) Positive edit terms --------
|
||||
# For each remaining edit e: -log true rate(e) - log token prob(e) if tokenized
|
||||
# loss_pos_per = sub_rate_hat.new_zeros(B) # [B]
|
||||
# for b, edits in enumerate(edits_list):
|
||||
# if not edits:
|
||||
# continue
|
||||
# cur_len = int(x_mask[b].sum().item())
|
||||
# for e in edits:
|
||||
# pos = e.pos
|
||||
# assert 0 <= pos < cur_len, f"pos {pos} out of range {cur_len}"
|
||||
# if e.kind == "SUB":
|
||||
# loss_pos_per[b] -= logQ_sub[b, pos, e.token] + safe_log(
|
||||
# sub_rate_hat[b, pos]
|
||||
# )
|
||||
# elif e.kind == "DEL":
|
||||
# loss_pos_per[b] -= safe_log(del_rate_hat[b, pos])
|
||||
# else: # "INS"
|
||||
# loss_pos_per[b] -= logQ_ins[b, pos, e.token] + safe_log(
|
||||
# ins_rate_hat[b, pos]
|
||||
# )
|
||||
|
||||
# -------- 7) Positive edit terms (vectorized) --------
|
||||
pos_sub, tok_sub, pos_ins, tok_ins, pos_del = [], [], [], [], []
|
||||
for b, edits in enumerate(edits_list):
|
||||
cur_len = int(x_mask[b].sum().item())
|
||||
ps, ts, pi, ti, pd = [], [], [], [], []
|
||||
for e in edits:
|
||||
if not (0 <= e.pos < cur_len):
|
||||
raise AssertionError(
|
||||
f"pos {e.pos} out of range {cur_len} for b={b}"
|
||||
)
|
||||
if e.kind == "SUB":
|
||||
ps.append(e.pos)
|
||||
ts.append(e.token)
|
||||
elif e.kind == "INS":
|
||||
pi.append(e.pos)
|
||||
ti.append(e.token)
|
||||
else:
|
||||
pd.append(e.pos)
|
||||
pos_sub.append(
|
||||
torch.tensor(ps, device=x_tok.device, dtype=torch.long) if ps else None
|
||||
)
|
||||
tok_sub.append(
|
||||
torch.tensor(ts, device=x_tok.device, dtype=torch.long) if ts else None
|
||||
)
|
||||
pos_ins.append(
|
||||
torch.tensor(pi, device=x_tok.device, dtype=torch.long) if pi else None
|
||||
)
|
||||
tok_ins.append(
|
||||
torch.tensor(ti, device=x_tok.device, dtype=torch.long) if ti else None
|
||||
)
|
||||
pos_del.append(
|
||||
torch.tensor(pd, device=x_tok.device, dtype=torch.long) if pd else None
|
||||
)
|
||||
|
||||
loss_pos_terms = []
|
||||
for b in range(B):
|
||||
lp = x_tok.new_zeros(())
|
||||
if pos_sub[b] is not None:
|
||||
lp = (
|
||||
lp
|
||||
- (
|
||||
logQ_sub[b, pos_sub[b], tok_sub[b]]
|
||||
+ safe_log(sub_rate_hat[b, pos_sub[b]])
|
||||
).sum()
|
||||
)
|
||||
if pos_ins[b] is not None:
|
||||
lp = (
|
||||
lp
|
||||
- (
|
||||
logQ_ins[b, pos_ins[b], tok_ins[b]]
|
||||
+ safe_log(ins_rate_hat[b, pos_ins[b]])
|
||||
).sum()
|
||||
)
|
||||
if pos_del[b] is not None:
|
||||
lp = lp - safe_log(del_rate_hat[b, pos_del[b]]).sum()
|
||||
loss_pos_terms.append(lp)
|
||||
loss_pos_per = torch.stack(loss_pos_terms) # [B]
|
||||
|
||||
# # Average positive term per sequence (MC estimator across batch)
|
||||
loss_pos = ((w * loss_pos_per) / denom).mean()
|
||||
|
||||
# -------- 8) Total --------
|
||||
loss = loss_surv + loss_pos + anchor
|
||||
return (loss, out) if return_outputs else loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
218
dllm/dllm/pipelines/editflow/utils.py
Normal file
@ -0,0 +1,218 @@
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Text
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from dllm.utils.utils import parse_spec
|
||||
|
||||
|
||||
# ------------------------------- Collator (x0 source) --------------------------------
|
||||
@dataclass
|
||||
class X0Sampler:
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list[int]:
|
||||
raise NotImplementedError("Subclasses must implement __call__.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampleX0Empty(X0Sampler):
|
||||
"""Return BOS-only (i.e., empty tail)."""
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list[int]:
|
||||
return []
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampleX0Masks(X0Sampler):
|
||||
"""Return a run of mask tokens of given length."""
|
||||
|
||||
length: int = 128
|
||||
tokenizer: transformers.PreTrainedTokenizer = None
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list[int]:
|
||||
mask_id = getattr(self.tokenizer, "mask_token_id", None)
|
||||
if mask_id is None:
|
||||
raise ValueError("tokenizer needs mask_token_id for mask-based sampler")
|
||||
return [int(mask_id)] * self.length
|
||||
|
||||
|
||||
# ---------------- Factory ---------------- #
|
||||
_X0_SAMPLER_CLASSES: dict[str, type[X0Sampler]] = {
|
||||
"empty": SampleX0Empty,
|
||||
"masks": SampleX0Masks,
|
||||
}
|
||||
|
||||
|
||||
def make_x0_sampler(name: str, tokenizer: Any, **kwargs) -> X0Sampler:
|
||||
try:
|
||||
name, kvs = parse_spec(name)
|
||||
cls = _X0_SAMPLER_CLASSES[name.lower()]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Unknown x0 sampler '{name}'. Available: {list(_X0_SAMPLER_CLASSES)}"
|
||||
)
|
||||
return cls(tokenizer=tokenizer, **kvs, **kwargs)
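
# Illustration (hypothetical tokenizer stub via types.SimpleNamespace; a real
# tokenizer supplies mask_token_id):
#   tok = types.SimpleNamespace(mask_token_id=3)
#   SampleX0Masks(length=4, tokenizer=tok)()   -> [3, 3, 3, 3]
#   SampleX0Empty()()                          -> []   (BOS-only x0)
# make_x0_sampler("masks", tok, length=4) builds the same sampler from a spec
# string, assuming `parse_spec` maps a bare name to an empty kwargs dict.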
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditFlowCollator:
|
||||
tokenizer: transformers.PreTrainedTokenizer = None
|
||||
x0_sampler: Callable | str | None = X0Sampler # can be func OR name
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.x0_sampler, str):
|
||||
self.x0_sampler = make_x0_sampler(self.x0_sampler, self.tokenizer)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, list[Any]]:
|
||||
if not features:
|
||||
return {}
|
||||
|
||||
keys = features[0].keys()
|
||||
batch = {k: [ex[k] for ex in features] for k in keys}
|
||||
batch["x1_ids"] = batch["input_ids"]
|
||||
|
||||
if "prompt_len" not in batch:
|
||||
assert self.tokenizer.bos_token_id is not None
|
||||
bos = self.tokenizer.bos_token_id
|
||||
batch["x1_ids"] = [
|
||||
x if x and x[0] == bos else [bos] + x for x in batch["x1_ids"]
|
||||
]
|
||||
batch["x0_ids"] = [
|
||||
x1_ids[:1] + self.x0_sampler(x1_ids=x1_ids[1:])
|
||||
for x1_ids in batch["x1_ids"]
|
||||
]
|
||||
else:
|
||||
batch["x0_ids"] = [
|
||||
x1_ids[:prompt_len] + self.x0_sampler(x1_ids=x1_ids[prompt_len:])
|
||||
for x1_ids, prompt_len in zip(batch["x1_ids"], batch["prompt_len"])
|
||||
]
|
||||
|
||||
batch["return_loss"] = True
|
||||
return batch
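
# Collation sketch (uses a types.SimpleNamespace stand-in tokenizer purely for
# illustration; real usage passes a transformers tokenizer):
#   tok = types.SimpleNamespace(bos_token_id=1, pad_token_id=0)
#   collator = EditFlowCollator(tokenizer=tok, x0_sampler=SampleX0Empty())
#   batch = collator([{"input_ids": [5, 6, 7]}])
#   batch["x1_ids"]      -> [[1, 5, 6, 7]]   # BOS prepended
#   batch["x0_ids"]      -> [[1]]            # BOS + empty tail from SampleX0Empty
#   batch["return_loss"] -> True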
|
||||
|
||||
|
||||
# ------------------------------- Trainer utils --------------------------------
|
||||
def pad_1d(
|
||||
batch_lists: list[list[int]], pad_val: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pads a list of variable-length integer lists into:
|
||||
- out: tensor of shape [B, Lmax] with padding value `pad_val`
|
||||
- mask: tensor of shape [B, Lmax] with 1 for real tokens and 0 for padding (int mask)
|
||||
"""
|
||||
B = len(batch_lists)
|
||||
Lmax = max((len(x) for x in batch_lists), default=0)
|
||||
out = torch.full((B, Lmax), pad_val, dtype=torch.long)
|
||||
mask = torch.zeros((B, Lmax), dtype=torch.long) # 0/1 mask (int)
|
||||
|
||||
for b, x in enumerate(batch_lists):
|
||||
if not x:
|
||||
continue
|
||||
L = len(x)
|
||||
out[b, :L] = torch.tensor(x, dtype=torch.long)
|
||||
mask[b, :L] = 1 # mark valid positions with 1
|
||||
|
||||
return out, mask
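
# Quick illustration (not in the original file):
#   out, mask = pad_1d([[1, 2, 3], [4]], pad_val=0)
#   out  -> tensor([[1, 2, 3], [4, 0, 0]])
#   mask -> tensor([[1, 1, 1], [1, 0, 0]])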
|
||||
|
||||
|
||||
def init_editflow_from_src(
|
||||
ef_model, src_model, lm_head_key: str = "lm_head", verbose: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize an EditFlowModel (ef_model) from a pretrained source model.
|
||||
|
||||
If DeepSpeed ZeRO-3 is enabled (detected via HF's `is_deepspeed_zero3_enabled()`),
|
||||
this function temporarily gathers full parameters for both models on rank 0,
|
||||
performs the copy there, and then returns to sharded mode automatically.
|
||||
Otherwise it behaves like a normal CPU/GPU single-process copy.
|
||||
|
||||
Returns (missing_keys, unexpected_keys) from load_state_dict(strict=False).
|
||||
"""
|
||||
import deepspeed
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
rank = torch.distributed.get_rank() if dist_ok else 0
|
||||
|
||||
def _copy_once():
|
||||
src_sd = src_model.state_dict()
|
||||
tgt_sd = ef_model.state_dict()
|
||||
new_sd = OrderedDict()
|
||||
|
||||
# 1) copy matching backbone tensors
|
||||
for k, v in src_sd.items():
|
||||
if k in tgt_sd and tgt_sd[k].shape == v.shape:
|
||||
new_sd[k] = v
|
||||
|
||||
# 2) duplicate lm_head -> sub_logits & ins_logits (weight + optional bias)
|
||||
lm_w = f"{lm_head_key}.weight"
|
||||
lm_b = f"{lm_head_key}.bias"
|
||||
|
||||
if lm_w in src_sd:
|
||||
if "sub_logits.weight" in tgt_sd:
|
||||
new_sd["sub_logits.weight"] = src_sd[lm_w]
|
||||
if "ins_logits.weight" in tgt_sd:
|
||||
new_sd["ins_logits.weight"] = src_sd[lm_w]
|
||||
if lm_b in src_sd:
|
||||
if "sub_logits.bias" in tgt_sd:
|
||||
new_sd["sub_logits.bias"] = src_sd[lm_b]
|
||||
if "ins_logits.bias" in tgt_sd:
|
||||
new_sd["ins_logits.bias"] = src_sd[lm_b]
|
||||
|
||||
# 3) non-strict load so new rate heads remain randomly initialized
|
||||
missing, unexpected = ef_model.load_state_dict(new_sd, strict=False)
|
||||
return new_sd, missing, unexpected
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
# All ranks enter/exit together; only rank 0 materializes full tensors.
|
||||
params = list(ef_model.parameters()) + list(src_model.parameters())
|
||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||
if rank == 0:
|
||||
new_sd, missing, unexpected = _copy_once()
|
||||
else:
|
||||
new_sd, missing, unexpected = OrderedDict(), [], []
|
||||
|
||||
if dist_ok:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if verbose and rank == 0:
|
||||
_p = getattr(globals().get("dllm", None), "utils", None)
|
||||
printer = getattr(_p, "print_main", print) if _p else print
|
||||
printer(
|
||||
f"[EditFlow init][ZeRO-3] Copied {len(new_sd)} tensors from Src Model."
|
||||
)
|
||||
if missing:
|
||||
printer(" Missing (expected for new rate heads, etc.):")
|
||||
for k in missing:
|
||||
printer(" -", k)
|
||||
if unexpected:
|
||||
printer(" Unexpected (check key names):")
|
||||
for k in unexpected:
|
||||
printer(" -", k)
|
||||
return missing, unexpected
|
||||
|
||||
# --- Non-ZeRO (or DS not present) path ---
|
||||
new_sd, missing, unexpected = _copy_once()
|
||||
if verbose:
|
||||
_p = getattr(globals().get("dllm", None), "utils", None)
|
||||
printer = getattr(_p, "print_main", print) if _p else print
|
||||
printer(f"[EditFlow init] Copied {len(new_sd)} tensors from Src Model.")
|
||||
if missing:
|
||||
printer(" Missing (expected for new rate heads, etc.):")
|
||||
for k in missing:
|
||||
printer(" -", k)
|
||||
if unexpected:
|
||||
printer(" Unexpected (check key names):")
|
||||
for k in unexpected:
|
||||
printer(" -", k)
|
||||
return missing, unexpected
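
# Wiring sketch (names are illustrative; the actual training entry point builds
# these objects from its own args):
#   src = transformers.AutoModelForMaskedLM.from_pretrained(
#       "GSAI-ML/LLaDA-8B-Base", trust_remote_code=True
#   )
#   ef = EditFlowLLaDAModel(EditFlowLLaDAConfig.from_pretrained("GSAI-ML/LLaDA-8B-Base"))
#   missing, unexpected = init_editflow_from_src(ef, src, lm_head_key="lm_head")
# `missing` should only list the new EditFlow heads; a non-empty `unexpected`
# usually means `lm_head_key` does not match the source model's head name.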
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
7
dllm/dllm/pipelines/llada/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from . import generator, trainer
from .models.modeling_llada import LLaDAModelLM
from .models.configuration_llada import LLaDAConfig
from .models.modeling_lladamoe import LLaDAMoEModelLM
from .models.configuration_lladamoe import LLaDAMoEConfig
from .generator import LLaDAGeneratorConfig, LLaDAGenerator
from .trainer import LLaDATrainer
357
dllm/dllm/pipelines/llada/eval.py
Normal file
@ -0,0 +1,357 @@
|
||||
"""
|
||||
accelerate launch \
|
||||
--num_processes 2 \
|
||||
dllm/pipelines/llada/eval.py \
|
||||
--tasks gsm8k \
|
||||
--model llada \
|
||||
--num_fewshot 8 \
|
||||
--model_args "pretrained=GSAI-ML/LLaDA-8B-Base,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
from lm_eval.api.instance import Instance
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.api.registry import register_model
|
||||
from lm_eval.models.utils import get_dtype
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDAEvalConfig(LLaDAGeneratorConfig):
|
||||
max_new_tokens: int = 1024
|
||||
max_length: int = 4096
|
||||
steps: int = 1024
|
||||
block_length: int = 1024
|
||||
|
||||
pretrained: str = ""
|
||||
dtype: str | torch.dtype = "auto"
|
||||
batch_size: int = 32
|
||||
mc_num: int = 128
|
||||
is_check_greedy: bool = True
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@register_model("llada")
|
||||
class LLaDAEvalHarness(LM):
|
||||
def __init__(
|
||||
self,
|
||||
config: LLaDAEvalConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = LLaDAEvalConfig()
|
||||
|
||||
# Pull args from config, allow kwargs to override
|
||||
pretrained = kwargs.get("pretrained", config.pretrained)
|
||||
dtype = kwargs.get("dtype", config.dtype)
|
||||
batch_size = kwargs.get("batch_size", config.batch_size)
|
||||
mc_num = kwargs.get("mc_num", config.mc_num)
|
||||
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
|
||||
device = kwargs.get("device", config.device)
|
||||
cfg = kwargs.get("cfg", config.cfg_scale)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
|
||||
accelerator = accelerate.Accelerator()
|
||||
|
||||
# Get GLOBAL rank from torch.distributed (not accelerator)
|
||||
if torch.distributed.is_initialized():
|
||||
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
|
||||
self._world_size = (
|
||||
torch.distributed.get_world_size()
|
||||
) # ← GLOBAL world size (16)
|
||||
else:
|
||||
self._rank = 0
|
||||
self._world_size = 1
|
||||
|
||||
# Use accelerator for device placement
|
||||
self.model = dllm.utils.get_model(
|
||||
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# Let accelerator handle device placement
|
||||
self.model = accelerator.prepare(self.model)
|
||||
self.device = (
|
||||
accelerator.device
|
||||
) # ← Accelerator figures out local device correctly
|
||||
self.accelerator = accelerator
|
||||
else:
|
||||
# Single GPU
|
||||
self.model = self.model.to(device)
|
||||
self.device = torch.device(device)
|
||||
self.accelerator = None
|
||||
|
||||
self.tokenizer = dllm.utils.get_tokenizer(
|
||||
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
|
||||
)
|
||||
|
||||
# generation params
|
||||
self.mask_id = self.tokenizer.mask_token_id
|
||||
self.batch_size = int(batch_size)
|
||||
self.max_length = max_length
|
||||
self.max_new_tokens = int(max_new_tokens)
|
||||
self.block_length = int(block_length)
|
||||
self.steps = int(steps)
|
||||
self.cfg = float(cfg)
|
||||
self.remasking = remasking
|
||||
self.is_check_greedy = is_check_greedy
|
||||
|
||||
# loglikelihood params
|
||||
self.mc_num = int(mc_num)
|
||||
assert mc_num % self.batch_size == 0
|
||||
self.sampling_eps = 0.0
|
||||
|
||||
def apply_chat_template(
|
||||
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Method to apply a chat template to a list of chat history between user and model.
|
||||
"""
|
||||
chat_templated = self.tokenizer.apply_chat_template(
|
||||
chat_history,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=not add_generation_prompt,
|
||||
)
|
||||
return chat_templated
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self.tokenizer.name_or_path.replace("/", "__")
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def _forward_process(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
b, l = batch.shape
|
||||
|
||||
target_len = (l - prompt_index.sum()).item()
|
||||
k = torch.randint(1, target_len + 1, (), device=batch.device)
|
||||
|
||||
x = torch.round(
|
||||
torch.linspace(
|
||||
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
|
||||
)
|
||||
).long()
|
||||
x = ((x - 1) % target_len) + 1
|
||||
assert x.min() >= 1 and x.max() <= target_len
|
||||
|
||||
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
|
||||
is_mask = indices < x.unsqueeze(1)
|
||||
|
||||
for i in range(b):
|
||||
is_mask[i] = is_mask[i][torch.randperm(target_len)]
|
||||
|
||||
is_mask = torch.cat(
|
||||
(
|
||||
torch.zeros(
|
||||
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
|
||||
),
|
||||
is_mask,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
noisy_batch = torch.where(is_mask, self.mask_id, batch)
|
||||
|
||||
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
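
# What `_forward_process` does, in words: each of the `b` copies of the sequence
# gets its own mask ratio k/target_len (spread roughly uniformly over (0, 1]),
# the masks land at random positions inside the answer region, prompt positions
# are never masked, and the second return value is that per-copy mask probability
# broadcast over the sequence, which `get_loglikelihood` divides by to form its
# Monte-Carlo estimate.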
|
||||
|
||||
@torch.no_grad()
|
||||
def get_logits(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if self.cfg > 0.0:
|
||||
assert len(prompt_index) == batch.shape[1]
|
||||
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
||||
un_batch = batch.clone()
|
||||
un_batch[prompt_index] = self.mask_id
|
||||
batch = torch.cat([batch, un_batch])
|
||||
|
||||
logits = self.model(batch).logits
|
||||
|
||||
if self.cfg > 0.0:
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
|
||||
return logits[:, : batch.shape[1]]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
|
||||
seq = torch.concatenate([prefix, target])[None, :]
|
||||
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
|
||||
loss_acc = []
|
||||
for _ in range(self.mc_num // self.batch_size):
|
||||
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
|
||||
|
||||
mask_indices = perturbed_seq == self.mask_id
|
||||
|
||||
logits = self.get_logits(perturbed_seq, prompt_index)
|
||||
|
||||
loss = (
|
||||
F.cross_entropy(
|
||||
logits[mask_indices], seq[mask_indices], reduction="none"
|
||||
)
|
||||
/ p_mask[mask_indices]
|
||||
)
|
||||
loss = loss.sum() / self.batch_size
|
||||
loss_acc.append(loss.item())
|
||||
|
||||
return -sum(loss_acc) / len(loss_acc)
|
||||
|
||||
@torch.no_grad()
|
||||
def suffix_greedy_prediction(
|
||||
self, prefix: torch.Tensor, target: torch.Tensor
|
||||
) -> bool:
|
||||
if not self.is_check_greedy:
|
||||
return False
|
||||
|
||||
seq = torch.full(
|
||||
(1, len(prefix) + len(target)), self.mask_id, device=self.device
|
||||
)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
prefix, target = prefix.to(self.device), target.to(self.device)
|
||||
seq[0, : len(prefix)] = prefix
|
||||
|
||||
for i in range(len(target)):
|
||||
mask_index = seq == self.mask_id
|
||||
logits = self.get_logits(seq, prompt_index)[mask_index]
|
||||
x0 = torch.argmax(logits, dim=-1)
|
||||
|
||||
p = torch.softmax(logits.to(torch.float32), dim=-1)
|
||||
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
|
||||
dim=-1
|
||||
)
|
||||
_, index = torch.sort(confidence, descending=True)
|
||||
x0[index[1:]] = self.mask_id
|
||||
seq[mask_index] = x0.clone()
|
||||
correct = target == seq[0, len(prefix) :]
|
||||
correct = torch.all(correct)
|
||||
return correct
|
||||
|
||||
def _encode_pair(
|
||||
self, context: str, continuation: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
n_spaces = len(context) - len(context.rstrip())
|
||||
if n_spaces > 0:
|
||||
continuation = context[-n_spaces:] + continuation
|
||||
context = context[:-n_spaces]
|
||||
|
||||
whole_enc = self.tokenizer(context + continuation)["input_ids"]
|
||||
context_enc = self.tokenizer(context)["input_ids"]
|
||||
|
||||
context_enc_len = len(context_enc)
|
||||
continuation_enc = whole_enc[context_enc_len:]
|
||||
|
||||
return context_enc, continuation_enc
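
# Example of the whitespace handling above (illustrative): for
# ("Answer:", " yes") nothing changes, but for ("Answer: ", "yes") the trailing
# space is moved onto the continuation first, so the continuation token ids are
# always a clean suffix of tokenize(context + continuation).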
|
||||
|
||||
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
||||
def _tokenize(e):
|
||||
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
||||
return {
|
||||
"prefix_text": e["prefix"],
|
||||
"target_text": e["target"],
|
||||
"prefix": prefix,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
|
||||
|
||||
assert max(prompt_len) <= 4096
|
||||
|
||||
out = []
|
||||
with torch.no_grad():
|
||||
for elem in tqdm(ds, desc="Computing likelihood..."):
|
||||
prefix = elem["prefix"]
|
||||
target = elem["target"]
|
||||
|
||||
ll = self.get_loglikelihood(prefix, target)
|
||||
|
||||
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
|
||||
|
||||
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
||||
torch.cuda.empty_cache()
|
||||
return out
|
||||
|
||||
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_until(self, requests: list[Instance]) -> list[str]:
|
||||
def _tokenize(e):
|
||||
return {
|
||||
"question": self.tokenizer(e["question"])["input_ids"],
|
||||
"question_text": e["question"],
|
||||
"until": e["until"],
|
||||
}
|
||||
|
||||
ds = [
|
||||
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
|
||||
]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
|
||||
out = []
|
||||
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
|
||||
|
||||
for elem in tqdm(ds, desc="Generating..."):
|
||||
prompt = [elem["question"].to(self.device)]
|
||||
stop_tokens = elem["until"]
|
||||
generated_ids = generator.generate(
|
||||
inputs=prompt,
|
||||
steps=self.steps,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
block_length=self.block_length,
|
||||
temperature=0.0,
|
||||
cfg_scale=self.cfg,
|
||||
remasking=self.remasking,
|
||||
)
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
|
||||
)
|
||||
for stop_seq in stop_tokens:
|
||||
if stop_seq in generated_answer:
|
||||
generated_answer = generated_answer.split(stop_seq)[0]
|
||||
|
||||
# remove special tokens
|
||||
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_answer_ids, skip_special_tokens=True
|
||||
)
|
||||
out.append(generated_answer)
|
||||
if self.accelerator is not None:
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_evaluate()
|
||||
379
dllm/dllm/pipelines/llada/generator.py
Normal file
@ -0,0 +1,379 @@
|
||||
"""
|
||||
reference: https://github.com/ML-GSAI/LLaDA/blob/main/generate.py
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from dllm.utils.generation_utils import get_num_transfer_tokens
|
||||
from dllm.core.generation.generator import (
|
||||
GeneratorOutput,
|
||||
GeneratorConfig,
|
||||
BaseGenerator,
|
||||
)
|
||||
|
||||
|
||||
def add_gumbel_noise(logits: torch.Tensor, temperature: float) -> torch.Tensor:
|
||||
"""
|
||||
The Gumbel max is a method for sampling categorical distributions.
|
||||
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
|
||||
Thus, we use float64.
|
||||
"""
|
||||
if temperature == 0:
|
||||
return logits
|
||||
logits = logits.to(torch.float64)
|
||||
noise = torch.rand_like(logits, dtype=torch.float64)
|
||||
gumbel_noise = (-torch.log(noise)) ** temperature
|
||||
return logits.exp() / gumbel_noise
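
# Sanity-check sketch (illustrative, not part of the reference code): with
# temperature == 1 this is the standard Gumbel-Max trick, since
# argmax(exp(l) / (-log u)) == argmax(l - log(-log u)), i.e. a sample from softmax(l).
#   logits = torch.tensor([1.0, 2.0, 0.5])
#   draws = torch.stack(
#       [torch.argmax(add_gumbel_noise(logits, temperature=1.0)) for _ in range(10_000)]
#   )
#   torch.bincount(draws, minlength=3) / 10_000   # ≈ softmax(logits) ≈ [0.23, 0.63, 0.14]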
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDAGeneratorConfig(GeneratorConfig):
|
||||
max_new_tokens: int = 128
|
||||
max_length: int | None = None  # no explicit length limit beyond the tokenizer/model context
|
||||
block_length: int = 128
|
||||
steps: int = 128
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
stochastic_transfer: bool = False
|
||||
cfg_scale: float = 0.0
|
||||
cfg_keep_tokens: list[int] | None = None
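
# Usage sketch (illustrative; assumes `model` and `tokenizer` are already loaded,
# and that `return_dict_in_generate` defaults to False, as eval.py above relies on).
# `LLaDAGenerator` is defined right below:
#   generator = LLaDAGenerator(model=model, tokenizer=tokenizer)
#   cfg = LLaDAGeneratorConfig(max_new_tokens=256, block_length=64, steps=256)
#   prompt_ids = tokenizer("Write a haiku about rain.")["input_ids"]
#   out = generator.generate([prompt_ids], config=cfg)   # kwargs can override cfg fields
#   text = tokenizer.decode(out[0][len(prompt_ids):], skip_special_tokens=True)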
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDAGenerator(BaseGenerator):
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: list[torch.Tensor | list],
|
||||
config: LLaDAGeneratorConfig | None = None,
|
||||
**kwargs,
|
||||
) -> GeneratorOutput | torch.Tensor:
|
||||
if config is None:
|
||||
config = LLaDAGeneratorConfig()
|
||||
|
||||
# ----- pull args from config, allow kwargs to override -----
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
|
||||
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
stochastic_transfer = kwargs.get(
|
||||
"stochastic_transfer", config.stochastic_transfer
|
||||
)
|
||||
return_dict_in_generate = kwargs.get(
|
||||
"return_dict_in_generate", config.return_dict_in_generate
|
||||
)
|
||||
|
||||
assert 1 <= block_length
|
||||
assert 1 <= steps
|
||||
mask_id = self.tokenizer.mask_token_id
|
||||
eos_id = self.tokenizer.eos_token_id
|
||||
|
||||
# ----- Shape bookkeeping: per-sample prompt lengths and final canvas width -----
|
||||
if isinstance(inputs[0], list):
|
||||
inputs = [
|
||||
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
|
||||
for p in inputs
|
||||
]
|
||||
prompt_lens = [p.shape[0] for p in inputs]
|
||||
|
||||
if max_new_tokens:
|
||||
max_length = max_new_tokens + max(prompt_lens)
|
||||
else:
|
||||
max_new_tokens = max_length - max(prompt_lens)
|
||||
|
||||
B = len(inputs)
|
||||
T = max_length
|
||||
|
||||
# ----- Initialize canvas with EOS, copy inputs, and append mask tail -----
|
||||
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
|
||||
for i, p in enumerate(inputs):
|
||||
x[i, : prompt_lens[i]] = p # keep original prompt tokens
|
||||
x[i, prompt_lens[i] : prompt_lens[i] + max_new_tokens] = (
|
||||
mask_id # append `max_new_tokens` masks to be generated
|
||||
)
|
||||
attention_mask = (x != eos_id).long() if B > 1 else None
|
||||
|
||||
# Tokens that were *given* at the start (non-mask, non-EOS).
|
||||
# These will be masked in the unconditional forward pass for CFG.
|
||||
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
|
||||
unmasked_index = (x != mask_id) & (x != eos_id)
|
||||
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
|
||||
keep_mask = torch.isin(
|
||||
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
|
||||
)
|
||||
unmasked_index = unmasked_index & ~keep_mask
|
||||
|
||||
# ----- Block scheduling over the appended mask tail -----
|
||||
num_blocks = math.ceil(max_new_tokens / block_length)
|
||||
steps = math.ceil(steps / num_blocks) # per-block step budget
|
||||
histories = [x.clone()] if return_dict_in_generate else None
|
||||
|
||||
for b in range(num_blocks):
|
||||
# Build a per-sample mask *within this block* (aligned to each prompt's tail)
|
||||
block_mask_index = torch.zeros(
|
||||
(B, block_length), dtype=torch.bool, device=x.device
|
||||
)
|
||||
|
||||
for j in range(B):
|
||||
start = prompt_lens[j] + b * block_length
|
||||
end = min(start + block_length, prompt_lens[j] + max_new_tokens, T)
|
||||
if start < end:
|
||||
width = end - start
|
||||
block_mask_index[j, :width] = (
|
||||
x[j, start:end] == mask_id
|
||||
) # which positions in this block are still masked
|
||||
|
||||
# Decide how many tokens to reveal per step in this block
|
||||
num_transfer_tokens = get_num_transfer_tokens(
|
||||
mask_index=block_mask_index,
|
||||
steps=steps,
|
||||
scheduler=self.scheduler,
|
||||
stochastic=stochastic_transfer,
|
||||
)
|
||||
|
||||
# Some steps may be skipped if there are no transfers
|
||||
effective_steps = num_transfer_tokens.size(1)
|
||||
|
||||
# ----- Iterative reveal inside the current block -----
|
||||
for i in range(effective_steps):
|
||||
mask_index = x == mask_id # current global mask map
|
||||
|
||||
# Optional CFG: second forward where original prompt tokens are masked out
|
||||
if cfg_scale > 0.0:
|
||||
un_x = x.clone()
|
||||
un_x[unmasked_index] = mask_id
|
||||
x_ = torch.cat([x, un_x], dim=0)
|
||||
logits = self.model(
|
||||
x_, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
||||
else:
|
||||
logits = self.model(
|
||||
x, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
|
||||
# Argmax decoding with optional Gumbel-Max noise for exploration
|
||||
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
||||
x0 = torch.argmax(
|
||||
logits_with_noise, dim=-1
|
||||
) # [B, T] predicted token ids
|
||||
|
||||
# Per-position confidence used to pick which masks to commit this step
|
||||
if remasking == "low_confidence":
|
||||
p = F.softmax(logits, dim=-1)
|
||||
x0_p = torch.squeeze(
|
||||
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
|
||||
) # [B, T] confidence of predicted token
|
||||
elif remasking == "random":
|
||||
x0_p = torch.rand(
|
||||
(x0.shape[0], x0.shape[1]), device=x0.device
|
||||
) # random scores
|
||||
else:
|
||||
raise NotImplementedError(remasking)
|
||||
|
||||
# Restrict selection window to the *current block's* tail region
|
||||
for j in range(B):
|
||||
x0_p[j, prompt_lens[j] + (b + 1) * block_length :] = -np.inf
|
||||
|
||||
# Only allow updates at currently masked positions; keep others fixed
|
||||
x0 = torch.where(mask_index, x0, x)
|
||||
confidence = torch.where(
|
||||
mask_index, x0_p, -np.inf
|
||||
) # consider masked positions only
|
||||
|
||||
# Pick exactly `num_transfer_tokens[j, i]` highest-confidence positions per sample
|
||||
transfer_index = torch.zeros_like(
|
||||
x0, dtype=torch.bool, device=x0.device
|
||||
)
|
||||
for j in range(confidence.shape[0]):
|
||||
_, select_index = torch.topk(
|
||||
confidence[j], k=num_transfer_tokens[j, i]
|
||||
)
|
||||
transfer_index[j, select_index] = True
|
||||
|
||||
# Commit chosen predictions into the canvas
|
||||
x[transfer_index] = x0[transfer_index]
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
# ----- Output format -----
|
||||
if not return_dict_in_generate:
|
||||
return x
|
||||
else:
|
||||
return GeneratorOutput(sequences=x, histories=histories)
|
||||
|
||||
@torch.no_grad()
|
||||
def infill(
|
||||
self, inputs: list[torch.Tensor | list], config, **kwargs
|
||||
) -> GeneratorOutput | torch.Tensor:
|
||||
"""
|
||||
Fill in-place the <|mdm_mask|> tokens contained in `inputs`.
|
||||
The whole (padded) sequence is split into block windows of length
|
||||
`block_length`; within each window we progressively "unmask" positions
|
||||
according to the scheduler and chosen remasking strategy.
|
||||
|
||||
Notes:
|
||||
- Right padding uses EOS.
|
||||
- CFG masks out *originally known* (non-mask, non-EOS) tokens in the
|
||||
unconditional branch, identical to `generate`.
|
||||
- Only masked positions are ever updated; non-mask tokens are left intact.
|
||||
"""
|
||||
# ----- pull args from config, allow kwargs to override -----
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
|
||||
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
stochastic_transfer = kwargs.get(
|
||||
"stochastic_transfer", config.stochastic_transfer
|
||||
)
|
||||
return_dict_in_generate = kwargs.get(
|
||||
"return_dict_in_generate", config.return_dict_in_generate
|
||||
)
|
||||
|
||||
mask_id = self.tokenizer.mask_token_id
|
||||
eos_id = self.tokenizer.eos_token_id
|
||||
|
||||
# ----- Build canvas: right-pad with EOS to the max length in the batch -----
|
||||
if isinstance(inputs[0], list):
|
||||
inputs = [
|
||||
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
|
||||
for p in inputs
|
||||
]
|
||||
|
||||
B = len(inputs)
|
||||
seq_lens = [t.shape[0] for t in inputs]
|
||||
T = max(seq_lens)
|
||||
|
||||
# Default to a single block spanning the whole sequence
|
||||
if block_length is None:
|
||||
block_length = T
|
||||
|
||||
assert 1 <= block_length
|
||||
assert 1 <= steps
|
||||
|
||||
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
|
||||
for i, t in enumerate(inputs):
|
||||
x[i, : seq_lens[i]] = t
|
||||
attention_mask = (x != eos_id).long() if B > 1 else None
|
||||
|
||||
# Tokens that were *given* at the start (non-mask, non-EOS).
|
||||
# These will be masked in the unconditional forward pass for CFG.
|
||||
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
|
||||
unmasked_index = (x != mask_id) & (x != eos_id)
|
||||
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
|
||||
keep_mask = torch.isin(
|
||||
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
|
||||
)
|
||||
unmasked_index = unmasked_index & ~keep_mask
|
||||
|
||||
# ----- Blockwise schedule over the *entire* (padded) sequence -----
|
||||
num_blocks = math.ceil(T / block_length)
|
||||
steps_per_block = math.ceil(steps / num_blocks)
|
||||
histories = [x.clone()] if return_dict_in_generate else None
|
||||
|
||||
# Create attention mask where eos_token_id is masked (set to 0)
|
||||
attention_mask = (x != eos_id).long()
|
||||
|
||||
for b in range(num_blocks):
|
||||
start = b * block_length
|
||||
stop = min(start + block_length, T)
|
||||
|
||||
# Per-sample view of which positions in this block are masks
|
||||
block_mask_index = torch.zeros(
|
||||
(B, block_length), dtype=torch.bool, device=self.model.device
|
||||
)
|
||||
widths = []
|
||||
for j in range(B):
|
||||
# Width limited by sample's true length and sequence end
|
||||
width = max(0, min(seq_lens[j], stop) - start)
|
||||
widths.append(width)
|
||||
if width > 0:
|
||||
block_mask_index[j, :width] = x[j, start : start + width] == mask_id
|
||||
|
||||
# Decide how many tokens to reveal at each step in this block
|
||||
num_transfer_tokens = get_num_transfer_tokens(
|
||||
mask_index=block_mask_index,
|
||||
steps=steps_per_block,
|
||||
scheduler=self.scheduler,
|
||||
stochastic=stochastic_transfer,
|
||||
)
|
||||
|
||||
# Some blocks may have no masks => effective_steps == 0
|
||||
effective_steps = num_transfer_tokens.size(1)
|
||||
|
||||
for s in range(effective_steps):
|
||||
mask_index_full = x == mask_id
|
||||
|
||||
# ----- Forward pass (+ optional CFG) -----
|
||||
if cfg_scale > 0.0:
|
||||
un_x = x.clone()
|
||||
un_x[unmasked_index] = mask_id
|
||||
x_ = torch.cat([x, un_x], dim=0)
|
||||
logits = self.model(
|
||||
x_, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
||||
else:
|
||||
logits = self.model(
|
||||
x, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
|
||||
# Greedy with optional Gumbel-Max noise
|
||||
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
||||
x0 = torch.argmax(logits_with_noise, dim=-1) # [B, T]
|
||||
|
||||
# Confidence used for choosing which masks to commit this step
|
||||
if remasking == "low_confidence":
|
||||
p = F.softmax(logits, dim=-1)
|
||||
x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(
|
||||
-1
|
||||
) # [B, T]
|
||||
elif remasking == "random":
|
||||
x0_p = torch.rand((B, T), device=self.model.device)
|
||||
else:
|
||||
raise NotImplementedError(remasking)
|
||||
|
||||
# Restrict selection to the *current* block only
|
||||
for j in range(B):
|
||||
end_j = start + widths[j]
|
||||
# Outside current block => impossible to select
|
||||
x0_p[j, :start] = -np.inf
|
||||
x0_p[j, end_j:] = -np.inf
|
||||
|
||||
# Only consider currently-masked positions as candidates
|
||||
x0 = torch.where(mask_index_full, x0, x)
|
||||
confidence = torch.where(mask_index_full, x0_p, -np.inf)
|
||||
|
||||
# Pick exactly num_transfer_tokens[j, s] positions per sample
|
||||
transfer_index = torch.zeros_like(x, dtype=torch.bool)
|
||||
for j in range(B):
|
||||
k = int(num_transfer_tokens[j, s].item())
|
||||
if k > 0:
|
||||
_, select_idx = torch.topk(confidence[j], k=k)
|
||||
transfer_index[j, select_idx] = True
|
||||
|
||||
# Commit selected predictions into the canvas
|
||||
x[transfer_index] = x0[transfer_index]
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
# ----- Output format -----
|
||||
if not return_dict_in_generate:
|
||||
return x
|
||||
else:
|
||||
return GeneratorOutput(sequences=x, histories=histories)
|
||||
19
dllm/dllm/pipelines/llada/models/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
from .configuration_llada import LLaDAConfig
from .modeling_llada import LLaDAModelLM
from .configuration_lladamoe import LLaDAMoEConfig
from .modeling_lladamoe import LLaDAMoEModelLM

# Register with HuggingFace Auto classes for local usage
try:
    from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM

    AutoConfig.register("llada", LLaDAConfig)
    AutoModel.register(LLaDAConfig, LLaDAModelLM)
    AutoModelForMaskedLM.register(LLaDAConfig, LLaDAModelLM)

    AutoConfig.register("lladamoe", LLaDAMoEConfig)
    AutoModel.register(LLaDAMoEConfig, LLaDAMoEModelLM)
    AutoModelForMaskedLM.register(LLaDAMoEConfig, LLaDAMoEModelLM)
except ImportError:
    # transformers not available or Auto classes not imported
    pass
|
||||
459
dllm/dllm/pipelines/llada/models/configuration_llada.py
Normal file
@ -0,0 +1,459 @@
|
||||
"""
|
||||
LLaDA configuration
|
||||
"""
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from enum import Enum
|
||||
from os import PathLike
|
||||
from typing import Union
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ActivationType",
|
||||
"ActivationCheckpointingStrategy",
|
||||
"BlockType",
|
||||
"LayerNormType",
|
||||
"InitFnType",
|
||||
"ModelConfig",
|
||||
]
|
||||
|
||||
PathOrStr = Union[str, PathLike]
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""
|
||||
This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
|
||||
We include it here for compatibility with older versions of Python.
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"'{str(self)}'"
|
||||
|
||||
|
||||
class LayerNormType(StrEnum):
|
||||
default = "default"
|
||||
"""
|
||||
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
|
||||
"""
|
||||
|
||||
low_precision = "low_precision"
|
||||
"""
|
||||
A low-precision version of the default LayerNorm.
|
||||
"""
|
||||
|
||||
rms = "rms"
|
||||
"""
|
||||
An RMSNorm implementation. When using ``torch.compile`` this is
|
||||
probably the fastest implementation.
|
||||
"""
|
||||
|
||||
gemma_rms = "gemma_rms"
|
||||
"""
|
||||
An RMSNorm implementation from Gemma. When using ``torch.compile`` this is
|
||||
probably the fastest implementation.
|
||||
"""
|
||||
|
||||
amd_compatible = "amd_compatible"
|
||||
"""
|
||||
LayerNorm implemented manually to work around an issue with ROCm.
|
||||
"""
|
||||
|
||||
|
||||
class ActivationType(StrEnum):
|
||||
gelu = "gelu"
|
||||
relu = "relu"
|
||||
silu = "silu"
|
||||
swiglu = "swiglu"
|
||||
|
||||
|
||||
class BlockType(StrEnum):
|
||||
sequential = "sequential"
|
||||
parallel = "parallel"
|
||||
|
||||
llama = "llama"
|
||||
"""
|
||||
A block similar to the sequential block with slightly different
|
||||
implementations of operations like attention to imitate the behavior of Llama.
|
||||
"""
|
||||
|
||||
|
||||
class InitFnType(StrEnum):
|
||||
mitchell = "mitchell"
|
||||
"""
|
||||
The strategy suggested to us by Mitchell Wortsman from UW.
|
||||
This uses a truncated normal distribution with an adaptive standard deviation that depends
|
||||
on the size of the weights as well as the depth of the layer.
|
||||
"""
|
||||
|
||||
normal = "normal"
|
||||
"""
|
||||
All weights are initialized from the same normal distribution.
|
||||
"""
|
||||
|
||||
kaiming_normal = "kaiming_normal"
|
||||
"""
|
||||
All weights are initialized with the Kaiming method from a normal distribution.
|
||||
Note this currently won't work with FSDP.
|
||||
"""
|
||||
|
||||
fan_in = "fan_in"
|
||||
"""
|
||||
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
|
||||
is the input dimensionality of the kernel.
|
||||
"""
|
||||
|
||||
full_megatron = "full_megatron"
|
||||
"""
|
||||
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""
|
||||
LLaDA (model) configuration.
|
||||
"""
|
||||
|
||||
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
|
||||
|
||||
d_model: int = 768
|
||||
"""
|
||||
The hidden size of the model.
|
||||
"""
|
||||
|
||||
n_heads: int = 12
|
||||
"""
|
||||
The number of self-attention heads.
|
||||
"""
|
||||
|
||||
n_kv_heads: Optional[int] = None
|
||||
"""
|
||||
The number of heads to use for keys and values. Defaults to `n_heads`.
|
||||
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
|
||||
Set this to 1 for multi-query attention.
|
||||
Set it to some in-between value for Llama2-style grouped query attention.
|
||||
"""
|
||||
|
||||
n_layers: int = 12
|
||||
"""
|
||||
The number of layers/blocks.
|
||||
"""
|
||||
|
||||
mlp_ratio: int = 4
|
||||
"""
|
||||
The ratio of the inner MLP dimensionality to ``d_model``.
|
||||
This is only used when ``mlp_hidden_size`` is not set.
|
||||
"""
|
||||
|
||||
mlp_hidden_size: Optional[int] = None
|
||||
"""
|
||||
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
|
||||
"""
|
||||
|
||||
activation_type: ActivationType = ActivationType.swiglu
|
||||
"""
|
||||
The activation function to use within the MLP layers.
|
||||
"""
|
||||
|
||||
block_type: BlockType = BlockType.sequential
|
||||
"""
|
||||
The transformer block implementation.
|
||||
"""
|
||||
|
||||
block_group_size: int = 1
|
||||
"""
|
||||
The number of blocks to group together into a single parent block.
|
||||
This has no effect on the number of parameters in the model and is only used to wrap groups
|
||||
of blocks together with a single FSDP wrapper during training.
|
||||
"""
|
||||
|
||||
alibi: bool = False
|
||||
"""
|
||||
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
|
||||
"""
|
||||
|
||||
alibi_bias_max: float = 8.0
|
||||
"""
|
||||
Maximum absolute value of ALiBi bias.
|
||||
"""
|
||||
|
||||
rope: bool = False
|
||||
"""
|
||||
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
|
||||
"""
|
||||
|
||||
rope_full_precision: bool = True
|
||||
"""
|
||||
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
|
||||
apply RoPE at the precision of the input.
|
||||
"""
|
||||
|
||||
flash_attention: bool = False
|
||||
"""
|
||||
If ``True``, use ``FlashAttention``.
|
||||
"""
|
||||
|
||||
attention_dropout: float = 0.1
|
||||
"""
|
||||
The dropout probability within the attention modules.
|
||||
"""
|
||||
|
||||
multi_query_attention: Optional[bool] = None
|
||||
"""
|
||||
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
|
||||
and is more efficient during inference.
|
||||
"""
|
||||
|
||||
attention_layer_norm: bool = False
|
||||
"""
|
||||
Apply layer norm to the keys and queries within the attention mechanism.
|
||||
This can help stabilize training.
|
||||
"""
|
||||
|
||||
residual_dropout: float = 0.1
|
||||
"""
|
||||
The dropout probability for the MLP and attention output within each block.
|
||||
"""
|
||||
|
||||
embedding_dropout: float = 0.1
|
||||
"""
|
||||
The dropout probability for embeddings.
|
||||
"""
|
||||
|
||||
input_emb_norm: bool = False
|
||||
"""
|
||||
An input hidden_states norm implementation from Gemma.
|
||||
"""
|
||||
|
||||
layer_norm_type: LayerNormType = LayerNormType.default
|
||||
"""
|
||||
The layernorm implementation to use.
|
||||
"""
|
||||
|
||||
layer_norm_with_affine: bool = True
|
||||
"""
|
||||
Whether to include bias and weight parameters for the layer norms.
|
||||
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
|
||||
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
|
||||
to ``False``.
|
||||
"""
|
||||
|
||||
rms_norm_eps: float = 1e-05
|
||||
"""
|
||||
The rms layernorm eps param.
|
||||
"""
|
||||
|
||||
attention_layer_norm_with_affine: bool = True
|
||||
"""
|
||||
Toggle affine transform for the QK norms.
|
||||
"""
|
||||
|
||||
max_sequence_length: int = 1024
|
||||
"""
|
||||
The maximum input sequence length supported by the model.
|
||||
"""
|
||||
|
||||
rope_theta: float = 10000.0
|
||||
"""
|
||||
The rope base param.
|
||||
"""
|
||||
|
||||
include_qkv_bias: Optional[bool] = False
|
||||
"""
|
||||
Whether or not to include bias parameters in qkv linear layers.
|
||||
"""
|
||||
|
||||
include_bias: bool = False
|
||||
"""
|
||||
Whether or not to include bias parameters in linear layers.
|
||||
In PaLM, they got rid of all bias terms because they found that large
|
||||
models tend to have near 0 bias terms anyway.
|
||||
"""
|
||||
|
||||
bias_for_layer_norm: Optional[bool] = None
|
||||
"""
|
||||
Whether or not to include bias parameters in layer norm.
|
||||
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
|
||||
layer norm.
|
||||
When this is None (the default), it inherits the setting from include_bias.
|
||||
"""
|
||||
|
||||
scale_logits: bool = False
|
||||
"""
|
||||
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
|
||||
"""
|
||||
|
||||
vocab_size: int = 50257
|
||||
"""
|
||||
Vocabulary size of the model.
|
||||
"""
|
||||
|
||||
embedding_size: Optional[int] = 50304
|
||||
"""
|
||||
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
|
||||
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
|
||||
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
|
||||
substantially.
|
||||
"""
|
||||
|
||||
weight_tying: bool = True
|
||||
"""
|
||||
Whether to tie output linear weights to the input embedding.
|
||||
"""
|
||||
|
||||
eos_token_id: int = 50256
|
||||
"""
|
||||
The ID of the end-of-sentence special token.
|
||||
"""
|
||||
|
||||
pad_token_id: int = 50256
|
||||
"""
|
||||
The ID of the token to use for padding. Defaults to the ID of the EOS token.
|
||||
"""
|
||||
|
||||
mask_token_id: Optional[int] = 50256
|
||||
"""
|
||||
The ID of the token to use for mask token. Defaults to the ID of the EOS token.
|
||||
"""
|
||||
|
||||
init_device: Optional[str] = None
|
||||
"""
|
||||
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
|
||||
"""
|
||||
|
||||
init_fn: InitFnType = InitFnType.normal
|
||||
"""
|
||||
The weight initialization strategy.
|
||||
"""
|
||||
|
||||
init_std: float = 0.02
|
||||
"""
|
||||
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
|
||||
as "normal".
|
||||
"""
|
||||
|
||||
init_cutoff_factor: Optional[float] = None
|
||||
"""
|
||||
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
|
||||
as "normal". Setting this to None means values are not cutoff.
|
||||
"""
|
||||
|
||||
precision: Optional[str] = None
|
||||
"""
|
||||
Precision used to train/evaluate with. You shouldn't set this directly.
|
||||
See :data:`TrainConfig.precision` instead.
|
||||
"""
|
||||
|
||||
@property
|
||||
def effective_n_kv_heads(self) -> int:
|
||||
if self.n_kv_heads is None:
|
||||
if self.multi_query_attention is True:
|
||||
return 1
|
||||
else:
|
||||
return self.n_heads
|
||||
else:
|
||||
if self.multi_query_attention is None:
|
||||
return self.n_kv_heads
|
||||
if self.multi_query_attention:
|
||||
n_kv_heads_should_be = 1
|
||||
else:
|
||||
n_kv_heads_should_be = self.n_heads
|
||||
if self.n_kv_heads == n_kv_heads_should_be:
|
||||
return n_kv_heads_should_be
|
||||
else:
|
||||
raise Exception(
|
||||
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
||||
)
|
||||
|
||||
class ActivationCheckpointingStrategy(StrEnum):
|
||||
whole_layer = "whole_layer"
|
||||
"""
|
||||
Checkpoint every transformer layer.
|
||||
"""
|
||||
|
||||
one_in_two = "one_in_two"
|
||||
"""
|
||||
Checkpoint one in two transformer layers.
|
||||
"""
|
||||
|
||||
one_in_three = "one_in_three"
|
||||
"""
|
||||
Checkpoint one in three transformer layers.
|
||||
"""
|
||||
|
||||
one_in_four = "one_in_four"
|
||||
"""
|
||||
Checkpoint one in four transformer layers.
|
||||
"""
|
||||
|
||||
two_in_three = "two_in_three"
|
||||
"""
|
||||
Checkpoint two out of every three transformer layers.
|
||||
"""
|
||||
|
||||
three_in_four = "three_in_four"
|
||||
"""
|
||||
Checkpoint three out of every four transformer layers.
|
||||
"""
|
||||
|
||||
four_in_five = "four_in_five"
|
||||
"""
|
||||
Checkpoint four out of every five transformer layers.
|
||||
"""
|
||||
|
||||
nine_in_ten = "nine_in_ten"
|
||||
"""
|
||||
Checkpoint nine out of every ten transformer layers.
|
||||
"""
|
||||
|
||||
fine_grained = "fine_grained"
|
||||
"""
|
||||
Focus checkpointing where recomputation is cheap and the memory savings are largest.
|
||||
"""
|
||||
|
||||
|
||||
class LLaDAConfig(PretrainedConfig):
|
||||
model_type = "llada"
|
||||
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
||||
|
||||
def __init__(self, use_cache: bool = False, **kwargs):
|
||||
model_config = ModelConfig()
|
||||
all_kwargs = model_config.__dict__
|
||||
all_kwargs.update(kwargs)
|
||||
all_kwargs.update({"use_cache": use_cache})
|
||||
all_kwargs.update(
|
||||
{
|
||||
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
|
||||
}
|
||||
)
|
||||
super().__init__(**all_kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.n_heads
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.n_layers
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.d_model
|
||||
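A minimal usage sketch (illustrative, not part of the diff; the configuration_llada module path is assumed from the surrounding file layout): LLaDAConfig copies the ModelConfig defaults into a PretrainedConfig and exposes HF-style aliases (hidden_size, num_attention_heads, num_hidden_layers) for the underlying field names.
from dllm.pipelines.llada.models.configuration_llada import LLaDAConfig  # assumed module path
config = LLaDAConfig(d_model=2048, n_heads=16, n_layers=24)
print(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)  # 2048 16 24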
96
dllm/dllm/pipelines/llada/models/configuration_lladamoe.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""
|
||||
LLaDA MoE configuration
|
||||
"""
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class LLaDAMoEConfig(PretrainedConfig):
|
||||
model_type = "lladamoe"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=-1,
|
||||
hidden_size=-1,
|
||||
dense_intermediate_size=-1,
|
||||
expert_intermediate_size=-1,
|
||||
shared_expert_intermediate_size=-1,
|
||||
num_hidden_layers=-1,
|
||||
num_attention_heads=-1,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=False,
|
||||
pad_token_id=1,
|
||||
bos_token_id=None,
|
||||
eos_token_id=50279,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=-1,
|
||||
partial_rotary_factor=-1,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
clip_qkv=None,
|
||||
num_experts_per_tok=-1,
|
||||
num_experts=-1,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.01,
|
||||
norm_topk_prob=None,
|
||||
qk_layernorm=None,
|
||||
moe_layer_freq=[],
|
||||
moe_router_enable_expert_bias=None,
|
||||
moe_router_score_function=None,
|
||||
routed_scaling_factor=1,
|
||||
router_num_group=-2,
|
||||
router_topk_group=-2,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.expert_intermediate_size = expert_intermediate_size
|
||||
self.dense_intermediate_size = dense_intermediate_size
|
||||
self.shared_expert_intermediate_size = shared_expert_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.clip_qkv = clip_qkv
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.qk_layernorm = qk_layernorm
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
|
||||
self.moe_router_score_function = moe_router_score_function
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.router_num_group = router_num_group
|
||||
self.router_topk_group = router_topk_group
|
||||
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
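A brief usage sketch (illustrative, not part of the diff), assuming the dllm package is importable: when num_key_value_heads is left unset it falls back to num_attention_heads.
from dllm.pipelines.llada.models.configuration_lladamoe import LLaDAMoEConfig
cfg = LLaDAMoEConfig(vocab_size=32000, hidden_size=1024, num_hidden_layers=4, num_attention_heads=8)
assert cfg.num_key_value_heads == 8  # defaulted from num_attention_heads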
1458
dllm/dllm/pipelines/llada/models/modeling_llada.py
Normal file
1168
dllm/dllm/pipelines/llada/models/modeling_lladamoe.py
Normal file
3
dllm/dllm/pipelines/llada/trainer.py
Normal file
@ -0,0 +1,3 @@
|
||||
from dllm.core.trainers import MDLMTrainer
|
||||
|
||||
LLaDATrainer = MDLMTrainer
|
||||
7
dllm/dllm/pipelines/rnd/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# from dllm.pipelines.rnd import generate, trainer
|
||||
from . import models
|
||||
from .models import RND1LM, RND1Config, RND1GenerationConfig
|
||||
|
||||
# from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
|
||||
# from dllm.pipelines.rnd.models.configuration_rnd import RND1Config
|
||||
from .trainer import RNDTrainer
|
||||
53
dllm/dllm/pipelines/rnd/models/__init__.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Radical Numerics Diffusion (RND1) - Diffusion-based Language Model.
|
||||
"""
|
||||
|
||||
from .configuration_rnd import RND1Config
|
||||
from .modeling_rnd import (
|
||||
RND1LM,
|
||||
RND1Model,
|
||||
RND1PreTrainedModel,
|
||||
RND1Attention,
|
||||
RND1DecoderLayer,
|
||||
RND1SparseMoeBlock,
|
||||
)
|
||||
from .generation_config import RND1GenerationConfig
|
||||
from .generation_utils import RND1GenerationMixin
|
||||
from .sampling import (
|
||||
diffusion_sample,
|
||||
apply_top_k_filtering,
|
||||
apply_top_p_filtering,
|
||||
)
|
||||
from .terminal_visualizer import TerminalVisualizer, SimpleProgressBar
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
__all__ = [
|
||||
"RND1Config",
|
||||
"RND1GenerationConfig",
|
||||
"RND1LM",
|
||||
"RND1Model",
|
||||
"RND1PreTrainedModel",
|
||||
"RND1Attention",
|
||||
"RND1DecoderLayer",
|
||||
"RND1SparseMoeBlock",
|
||||
"RND1GenerationMixin",
|
||||
"TerminalVisualizer",
|
||||
"SimpleProgressBar",
|
||||
]
|
||||
|
||||
# Register with HuggingFace Auto classes for local usage
|
||||
try:
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
|
||||
|
||||
AutoConfig.register("rnd1", RND1Config)
|
||||
AutoModel.register(RND1Config, RND1LM)
|
||||
AutoModelForMaskedLM.register(RND1Config, RND1LM)
|
||||
except ImportError:
|
||||
# transformers not available or Auto classes not imported
|
||||
pass
|
||||
124
dllm/dllm/pipelines/rnd/models/configuration_rnd.py
Normal file
@ -0,0 +1,124 @@
|
||||
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 Model Configuration.
|
||||
|
||||
This module defines the configuration class for RND1 models.
|
||||
The default settings are derived from Qwen/Qwen3-30B-A3B and augmented
|
||||
with RND1-specific parameters.
|
||||
"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
# Qwen3-30B-A3B / checkpoint defaults
|
||||
CONFIG_DEFAULTS = {
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"decoder_sparse_step": 1,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 6144,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 48,
|
||||
"mlp_only_layers": [],
|
||||
"moe_intermediate_size": 768,
|
||||
"norm_topk_prob": True,
|
||||
"num_attention_heads": 32,
|
||||
"num_experts": 128,
|
||||
"num_experts_per_tok": 8,
|
||||
"num_hidden_layers": 48,
|
||||
"num_key_value_heads": 4,
|
||||
"output_router_logits": False,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": False,
|
||||
"rope_theta": 1000000.0,
|
||||
"router_aux_loss_coef": 0.001,
|
||||
"sliding_window": False,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936,
|
||||
}
|
||||
|
||||
|
||||
class RND1Config(PretrainedConfig):
|
||||
"""
|
||||
Configuration class for RND1 models.
|
||||
|
||||
This configuration extends Qwen3MoeConfig with additional parameters
|
||||
specific to the RND1 (Radical Numerics Diffusion v1) architecture.
|
||||
|
||||
Args:
|
||||
moe_backend: Backend for MoE computation ("hf", "vllm", "sglang" or "flashinfer")
|
||||
num_diffusion_steps: Default number of diffusion steps for generation
|
||||
mask_token_id: Token ID used for masking (default: 151669 for Qwen)
|
||||
**kwargs: Additional arguments passed to Qwen3MoeConfig
|
||||
"""
|
||||
|
||||
model_type = "rnd1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_backend: str = "hf",
|
||||
num_diffusion_steps: int = 256,
|
||||
mask_token_id: int = 151669,
|
||||
**kwargs,
|
||||
):
|
||||
# Force non-causal and no caching for RND1
|
||||
kwargs["use_cache"] = False
|
||||
kwargs["is_causal"] = False
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set defaults after pretrained init to prevent overrides
|
||||
self.set_config_defaults()
|
||||
|
||||
# QoL: set attn impl directly from config
|
||||
if "attn_implementation" in kwargs:
|
||||
self._attn_implementation = kwargs["attn_implementation"]
|
||||
|
||||
# RND1-specific parameters
|
||||
self.moe_backend = moe_backend
|
||||
self.num_diffusion_steps = num_diffusion_steps
|
||||
self.mask_token_id = mask_token_id
|
||||
|
||||
# Ensure bidirectional attention and no caching
|
||||
self.is_causal = False
|
||||
self.use_cache = False
|
||||
|
||||
def set_config_defaults(self):
|
||||
"""
|
||||
Ensure model defaults are set according to the final training checkpoint.
|
||||
|
||||
Qwen3MoeConfig defaults don't match Qwen/Qwen3-30B-A3B settings from which
|
||||
RND1 is derived.
|
||||
"""
|
||||
for k, v in CONFIG_DEFAULTS.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes configuration to dictionary with auto_map for Hub.
|
||||
|
||||
The auto_map ensures that when users load from HuggingFace Hub,
|
||||
the correct custom classes are automatically resolved.
|
||||
"""
|
||||
data = super().to_dict()
|
||||
data.setdefault(
|
||||
"auto_map",
|
||||
{
|
||||
"AutoConfig": "configuration_rnd.RND1Config",
|
||||
"AutoModel": "modeling_rnd.RND1Model",
|
||||
"AutoModelForMaskedLM": "modeling_rnd.RND1LM",
|
||||
},
|
||||
)
|
||||
return data
|
||||
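A short usage sketch (not part of the diff), assuming the dllm package is importable: the constructor pins the Qwen3-30B-A3B-derived defaults, forces bidirectional cache-free operation, and to_dict() injects the auto_map used for Hub loading.
from dllm.pipelines.rnd.models.configuration_rnd import RND1Config
cfg = RND1Config(moe_backend="hf", num_diffusion_steps=128)
assert cfg.use_cache is False and cfg.is_causal is False
print(cfg.hidden_size, cfg.num_experts)                   # 2048 128 (checkpoint defaults)
print(cfg.to_dict()["auto_map"]["AutoModelForMaskedLM"])  # modeling_rnd.RND1LM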
77
dllm/dllm/pipelines/rnd/models/generation_config.py
Normal file
@ -0,0 +1,77 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 Generation Configuration.
|
||||
|
||||
This module defines the generation configuration for RND1 models,
|
||||
controlling the diffusion-based generation process.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
|
||||
|
||||
class RND1GenerationConfig(GenerationConfig):
|
||||
"""
|
||||
Configuration class for RND1 generation parameters.
|
||||
|
||||
This class extends the base GenerationConfig to include parameters
|
||||
specific to diffusion-based language generation.
|
||||
|
||||
Args:
|
||||
max_length: Maximum sequence length
|
||||
num_diffusion_steps: Number of denoising steps in the diffusion process
|
||||
mask_token_id: Token ID used for masking during diffusion
|
||||
temperature: Temperature for sampling (higher = more random)
|
||||
top_k: Optional top-k filtering
|
||||
top_p: Optional nucleus (top-p) filtering
|
||||
greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
|
||||
**kwargs: Additional arguments passed to GenerationConfig
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_length: int = 256,
|
||||
num_diffusion_steps: int = 256,
|
||||
mask_token_id: int = 151669,
|
||||
temperature: float = 0.1,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
greedy: bool = False,
|
||||
bos_token_id: int = None,
|
||||
eos_token_id: int = None,
|
||||
pad_token_id: int = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Force no caching for RND generation
|
||||
# kwargs['use_cache'] = False
|
||||
kwargs.pop('use_cache', None)
|
||||
super().__init__(
|
||||
max_length=max_length,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=not greedy,
|
||||
use_cache=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# RND-specific parameters
|
||||
self.num_diffusion_steps = num_diffusion_steps
|
||||
self.mask_token_id = mask_token_id
|
||||
self.greedy = greedy
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert configuration to dictionary."""
|
||||
output = super().to_dict()
|
||||
output["num_diffusion_steps"] = self.num_diffusion_steps
|
||||
output["mask_token_id"] = self.mask_token_id
|
||||
output["greedy"] = self.greedy
|
||||
return output
|
||||
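Illustrative usage (not part of the diff): greedy is translated into HF's do_sample, caching stays disabled, and the RND-specific fields survive serialization.
from dllm.pipelines.rnd.models.generation_config import RND1GenerationConfig
gen_cfg = RND1GenerationConfig(max_length=128, num_diffusion_steps=64, greedy=True)
assert gen_cfg.do_sample is False and gen_cfg.use_cache is False
assert gen_cfg.to_dict()["num_diffusion_steps"] == 64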
187
dllm/dllm/pipelines/rnd/models/generation_utils.py
Normal file
@ -0,0 +1,187 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 Generation Utilities.
|
||||
|
||||
This module provides generation utilities and mixins for RND1 models,
|
||||
including the main GenerationMixin class that integrates with HuggingFace.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from transformers import GenerationMixin as HFGenerationMixin
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
|
||||
|
||||
|
||||
class RND1GenerationMixin(HFGenerationMixin):
|
||||
"""
|
||||
Generation mixin for RND1 models.
|
||||
|
||||
This mixin provides generation methods compatible with HuggingFace's
|
||||
generation API while using RND1's diffusion-based sampling internally.
|
||||
"""
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.LongTensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
# RND1-specific parameters
|
||||
prefix_ids: Optional[torch.LongTensor] = None,
|
||||
suffix_ids: Optional[torch.LongTensor] = None,
|
||||
infill_length: Optional[int] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
**kwargs, # Accept all kwargs to be compatible with pipelines
|
||||
) -> Union[torch.LongTensor, Dict[str, Any]]:
|
||||
"""
|
||||
Generate text using RND1's diffusion-based sampling.
|
||||
|
||||
Follows HuggingFace's standard generate API, using diffusion sampling
|
||||
internally. Supports both standard generation and infilling.
|
||||
|
||||
Args:
|
||||
inputs: Input token IDs to use as prefix (standard HF parameter)
|
||||
generation_config: Generation configuration object
|
||||
prefix_ids: Alternative to inputs for infilling tasks
|
||||
suffix_ids: Optional suffix for infilling tasks
|
||||
infill_length: Length of infill region (for infilling)
|
||||
return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
|
||||
**kwargs: Additional arguments (accepted for compatibility)
|
||||
|
||||
Returns:
|
||||
Generated token IDs or GenerateDecoderOnlyOutput
|
||||
"""
|
||||
if generation_config is not None:
|
||||
gen_config = generation_config
|
||||
model_kwargs = kwargs.copy()
|
||||
else:
|
||||
# Only prepare config from kwargs if no config was provided
|
||||
gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
if inputs is not None:
|
||||
prefix_ids = inputs.to(device)
|
||||
elif prefix_ids is not None:
|
||||
prefix_ids = prefix_ids.to(device)
|
||||
else:
|
||||
prefix_ids = None
|
||||
|
||||
if suffix_ids is not None:
|
||||
suffix_ids = suffix_ids.to(device)
|
||||
|
||||
eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
|
||||
eos_token_id = None if eos_token_id == -1 else eos_token_id
|
||||
pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
|
||||
bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
|
||||
mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
|
||||
|
||||
if infill_length is not None and prefix_ids is not None:
|
||||
# Infilling mode: use specified infill_length
|
||||
prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
|
||||
suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
|
||||
seq_len = prefix_len + infill_length + suffix_len
|
||||
else:
|
||||
# Standard generation mode
|
||||
if prefix_ids is not None:
|
||||
prefix_len = prefix_ids.shape[1]
|
||||
if gen_config.max_new_tokens is not None:
|
||||
seq_len = prefix_len + gen_config.max_new_tokens
|
||||
else:
|
||||
seq_len = gen_config.max_length or self.config.max_position_embeddings
|
||||
else:
|
||||
seq_len = gen_config.max_length or self.config.max_position_embeddings
|
||||
|
||||
num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
|
||||
getattr(self.config, "num_diffusion_steps", 256))
|
||||
|
||||
temperature = float(getattr(gen_config, "temperature", 1.0))
|
||||
top_k = getattr(gen_config, "top_k", None)
|
||||
top_p = getattr(gen_config, "top_p", None)
|
||||
|
||||
greedy = getattr(gen_config, "greedy",
|
||||
not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
sequences = diffusion_sample(
|
||||
model=self,
|
||||
seq_len=seq_len,
|
||||
num_steps=num_diffusion_steps,
|
||||
mask_token_id=mask_token_id,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
greedy=greedy,
|
||||
prefix_ids=prefix_ids,
|
||||
suffix_ids=suffix_ids,
|
||||
infill_length=infill_length,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
device=device,
|
||||
visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs
|
||||
)
|
||||
|
||||
if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
|
||||
from transformers.generation.utils import GenerateDecoderOnlyOutput
|
||||
return GenerateDecoderOnlyOutput(sequences=sequences)
|
||||
|
||||
return sequences
|
||||
|
||||
def generate_with_visualization(
|
||||
self,
|
||||
tokenizer,
|
||||
inputs: Optional[torch.LongTensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
suffix_ids: Optional[torch.LongTensor] = None,
|
||||
infill_length: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generate with live visualization (for demos).
|
||||
|
||||
This method requires a tokenizer to display the generation process.
|
||||
For production use, prefer `generate()`.
|
||||
|
||||
Args:
|
||||
tokenizer: Tokenizer for decoding tokens to text
|
||||
inputs: Input token IDs to use as prefix
|
||||
generation_config: Generation configuration object
|
||||
suffix_ids: Optional suffix token IDs
|
||||
infill_length: Length of infill region
|
||||
**kwargs: Additional arguments for backward compatibility
|
||||
|
||||
Returns:
|
||||
Generated token IDs as LongTensor
|
||||
"""
|
||||
from .terminal_visualizer import TerminalVisualizer
|
||||
visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
|
||||
|
||||
return self.generate(
|
||||
inputs=inputs,
|
||||
generation_config=generation_config,
|
||||
suffix_ids=suffix_ids,
|
||||
infill_length=infill_length,
|
||||
visualizer=visualizer,
|
||||
return_dict_in_generate=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare inputs for generation (required by HuggingFace).
|
||||
|
||||
For RND1, we don't use the standard autoregressive generation,
|
||||
so this just returns the input_ids.
|
||||
"""
|
||||
return {"input_ids": input_ids}
|
||||
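A minimal end-to-end sketch (not part of the diff); the checkpoint path is a placeholder, and the snippet assumes an RND1 checkpoint with a matching tokenizer plus a CUDA device are available.
import torch
from transformers import AutoTokenizer
from dllm.pipelines.rnd import RND1LM, RND1GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("<rnd1-checkpoint>")  # placeholder path
model = RND1LM.from_pretrained("<rnd1-checkpoint>", torch_dtype=torch.bfloat16).cuda()
prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
gen_cfg = RND1GenerationConfig(max_new_tokens=32, num_diffusion_steps=64, greedy=True)
output_ids = model.generate(inputs=prompt_ids, generation_config=gen_cfg)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))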
653
dllm/dllm/pipelines/rnd/models/modeling_rnd.py
Normal file
@ -0,0 +1,653 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 model implementation.
|
||||
|
||||
This module implements the RND1 architecture with bidirectional attention for
|
||||
diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
|
||||
with multiple backend options (HF, vLLM, SGLang, FlashInfer).
|
||||
|
||||
Based on the Qwen3Moe architecture:
|
||||
https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, Tuple, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.utils import logging
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import (
|
||||
MoeModelOutputWithPast,
|
||||
MaskedLMOutput,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
from .configuration_rnd import RND1Config
|
||||
from .generation_utils import RND1GenerationMixin
|
||||
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||
Qwen3MoeRMSNorm,
|
||||
Qwen3MoeRotaryEmbedding,
|
||||
Qwen3MoeMLP,
|
||||
apply_rotary_pos_emb
|
||||
)
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts as fused_experts_vllm, fused_topk as fused_topk_vllm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm as VLLMRMSNorm
|
||||
except Exception:
|
||||
fused_experts_vllm = None
|
||||
fused_topk_vllm = None
|
||||
|
||||
try:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
|
||||
# from sglang.srt.layers.layernorm import RMSNorm as SGLangRMSNorm # TODO: buggy atm
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
except Exception:
|
||||
sglang_fused_moe = None
|
||||
StandardTopKOutput = None
|
||||
|
||||
|
||||
try:
|
||||
import flashinfer.fused_moe as fused_moe
|
||||
## TODO: below needs flashinfer>=0.4.0, but has some bug atm
|
||||
# from flashinfer.norm import rmsnorm as flashinfer_rmsnorm
|
||||
# class FlashInferRMSNorm(Qwen3MoeRMSNorm):
|
||||
# """Wrapper around FlashInfer RMSNorm to match Qwen3MoeRMSNorm interface"""
|
||||
# def forward(self, hidden_states):
|
||||
# return flashinfer_rmsnorm(hidden_states, self.weight, self.variance_epsilon)
|
||||
|
||||
except Exception:
|
||||
fused_moe = None
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""Expand key/value heads to match query heads for grouped-query attention."""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
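# Usage sketch (illustrative, not part of this file): repeat_kv expands 4 KV heads
# 8x so that 32 query heads can share them under grouped-query attention.
_kv_example = torch.randn(2, 4, 16, 128)  # [batch, num_kv_heads, seq_len, head_dim]
assert repeat_kv(_kv_example, n_rep=8).shape == (2, 32, 16, 128)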
|
||||
|
||||
class RND1Attention(nn.Module):
|
||||
"""RND1 attention layer with bidirectional attention for diffusion modeling."""
|
||||
|
||||
def __init__(self, config: RND1Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
|
||||
|
||||
if config.moe_backend == "vllm":
|
||||
RMSNormClass = VLLMRMSNorm
|
||||
else:
|
||||
RMSNormClass = Qwen3MoeRMSNorm
|
||||
self.q_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
self.sliding_window = getattr(config, "sliding_window", None)
|
||||
|
||||
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
dual_cache: Optional[bool] = False,
|
||||
replace_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]]]:
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa")
|
||||
|
||||
if use_sdpa:
|
||||
if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
|
||||
if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
|
||||
attention_mask = attention_mask.to(dtype=query_states.dtype)
|
||||
|
||||
assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=self.is_causal,
|
||||
)
|
||||
attn_out = attn_out.transpose(1, 2).contiguous()
|
||||
attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_out = self.o_proj(attn_out)
|
||||
return attn_out, None
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
# TODO: modify this to boolean masks
|
||||
attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
|
||||
attn_out = torch.matmul(attn_weights, value_states)
|
||||
attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1)
|
||||
attn_out = self.o_proj(attn_out)
|
||||
|
||||
return attn_out, None
|
||||
|
||||
|
||||
class RND1DecoderLayer(nn.Module):
|
||||
"""RND1 decoder layer with bidirectional attention for diffusion language modeling."""
|
||||
|
||||
def __init__(self, config: RND1Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = RND1Attention(config, layer_idx)
|
||||
self.mlp = RND1SparseMoeBlock(config)
|
||||
if config.moe_backend == "vllm":
|
||||
RMSNormClass = VLLMRMSNorm
|
||||
else:
|
||||
RMSNormClass = Qwen3MoeRMSNorm
|
||||
self.input_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
replace_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
attn_out, attn_weights = self.self_attn(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
position_embeddings=position_embeddings,
|
||||
replace_position=replace_position,
|
||||
)
|
||||
hidden_states = residual + attn_out
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
ff_out = self.mlp(hidden_states)
|
||||
if isinstance(ff_out, tuple):
|
||||
ff_out = ff_out[0]
|
||||
hidden_states = residual + ff_out
|
||||
|
||||
return hidden_states, attn_weights
|
||||
|
||||
|
||||
class RND1SparseMoeBlock(nn.Module):
|
||||
"""RND1 Sparse MoE block with multiple backend support (HF, vLLM, SGLang, FlashInfer)."""
|
||||
|
||||
def __init__(self, config: RND1Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.backend = getattr(config, "moe_backend", "hf")
|
||||
self.num_experts = config.num_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)
|
||||
|
||||
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
|
||||
self.experts = nn.ModuleList(
|
||||
[Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
|
||||
)
|
||||
|
||||
# Cached weight tensors for optimized backends
|
||||
self._w1 = None
|
||||
self._w2 = None
|
||||
if self.backend == "sglang":
|
||||
if sglang_fused_moe is None or StandardTopKOutput is None:
|
||||
raise RuntimeError("sglang is not available, cannot use sglang backend")
|
||||
elif self.backend == "flashinfer":
|
||||
if fused_moe is None:
|
||||
raise RuntimeError("flashinfer is not available, cannot use flashinfer backend")
|
||||
elif self.backend == "vllm":
|
||||
if fused_experts_vllm is None or fused_topk_vllm is None:
|
||||
raise RuntimeError("vllm is not available, cannot use vllm backend")
|
||||
|
||||
@torch.no_grad()
|
||||
def _initialize_weights(
|
||||
self,
|
||||
free_experts: bool = True,
|
||||
mode: str = "vllm",
|
||||
) -> None:
|
||||
logger.info(f"Initializing weights for {mode} backend")
|
||||
# Stack directly on device where weights already reside (loaded by HF)
|
||||
gate_list: List[torch.Tensor] = []
|
||||
up_list: List[torch.Tensor] = []
|
||||
down_list: List[torch.Tensor] = []
|
||||
|
||||
# Collect weight references without any device moves
|
||||
for expert in self.experts:
|
||||
gate_list.append(expert.gate_proj.weight.data)
|
||||
up_list.append(expert.up_proj.weight.data)
|
||||
down_list.append(expert.down_proj.weight.data)
|
||||
|
||||
gate_w_stacked = torch.stack(gate_list, dim=0).contiguous()
|
||||
up_w_stacked = torch.stack(up_list, dim=0).contiguous()
|
||||
down_w_stacked = torch.stack(down_list, dim=0).contiguous()
|
||||
|
||||
if mode == "flashinfer":
|
||||
w1 = torch.cat([up_w_stacked, gate_w_stacked], dim=1) # FlashInfer expects [up; gate] ordering
|
||||
else:
|
||||
w1 = torch.cat([gate_w_stacked, up_w_stacked], dim=1)
|
||||
w2 = down_w_stacked
|
||||
self._w1 = w1
|
||||
self._w2 = w2
|
||||
|
||||
|
||||
if free_experts:
|
||||
# Free per-expert modules to reclaim memory
|
||||
logger.info(f"Freeing experts for {mode} backend")
|
||||
del self.experts
|
||||
self.experts = None
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass with expert routing and computation."""
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
x = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Expert routing
|
||||
router_logits = self.gate(x)
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
|
||||
if self.backend == "vllm":
|
||||
routing_weights, selected_experts, *_ = fused_topk_vllm(
|
||||
hidden_states=x,
|
||||
gating_output=router_logits,
|
||||
topk=self.top_k,
|
||||
renormalize=self.norm_topk_prob,
|
||||
)
|
||||
else:
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
if self.norm_topk_prob:
|
||||
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
# if self.backend == "hf":
|
||||
# final_hidden_states = torch.zeros(
|
||||
# (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
||||
# )
|
||||
|
||||
# expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
# expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
# for expert_idx in expert_hit:
|
||||
# expert_layer = self.experts[expert_idx]
|
||||
# idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
||||
# current_state = x[top_x]
|
||||
# current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
||||
# final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
# out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
# return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
if self.backend == "hf":
|
||||
# Accumulate buffer: [B*S, H]
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# expert_mask: [E, top_k, tokens]
|
||||
expert_mask = torch.nn.functional.one_hot(
|
||||
selected_experts, num_classes=self.num_experts
|
||||
).permute(2, 1, 0).contiguous()
|
||||
|
||||
# Iterate over all experts in order; even if this rank has no hits, still enter the forward pass to avoid ZeRO-3 control-flow divergence
|
||||
for e in range(self.num_experts):
|
||||
expert_layer = self.experts[int(e)]
|
||||
|
||||
# Gather the indices of the tokens routed to this expert
|
||||
idx, top_x = torch.where(expert_mask[e]) # idx∈[0, top_k), shapes: [n_tok_e]
|
||||
current_state = x[top_x] # [n_tok_e, H]; n_tok_e may be 0
|
||||
# if top_x.numel() == 0:
|
||||
# print("0")
|
||||
|
||||
# Run the forward pass even on an empty batch; most Linear/MLP layers are a no-op on 0-row input, but this keeps the ZeRO-3 parameter path aligned
|
||||
expert_out = expert_layer(current_state) # [n_tok_e, H]
|
||||
|
||||
# Fetch the routing weights and apply them
|
||||
w = routing_weights[top_x, idx] # [n_tok_e]
|
||||
expert_out = expert_out * w.unsqueeze(-1) # [n_tok_e, H]
|
||||
|
||||
# Accumulate back into the global buffer; a valid no-op when n_tok_e == 0
|
||||
final_hidden_states.index_add_(0, top_x, expert_out.to(hidden_states.dtype))
|
||||
|
||||
out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
elif self.backend == "flashinfer":
|
||||
# if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
|
||||
# self._initialize_flashinfer_weights()
|
||||
if self._w1 is None or self._w2 is None:
|
||||
self._initialize_weights(mode="flashinfer")
|
||||
|
||||
result = fused_moe.cutlass_fused_moe(
|
||||
input=x,
|
||||
token_selected_experts=selected_experts.to(torch.int),
|
||||
token_final_scales=routing_weights.to(torch.float32),
|
||||
fc1_expert_weights=self._w1,
|
||||
fc2_expert_weights=self._w2,
|
||||
output_dtype=x.dtype,
|
||||
quant_scales=None,
|
||||
)
|
||||
if isinstance(result, (list, tuple)):
|
||||
out_flat = result[0]
|
||||
else:
|
||||
out_flat = result
|
||||
out = out_flat.view(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
elif self.backend == "sglang":
|
||||
if self._w1 is None or self._w2 is None:
|
||||
self._initialize_weights(mode="sglang")
|
||||
|
||||
topk_output = StandardTopKOutput(
|
||||
topk_weights=routing_weights,
|
||||
topk_ids=selected_experts,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
out_flat = sglang_fused_moe(
|
||||
hidden_states=x,
|
||||
w1=self._w1,
|
||||
w2=self._w2,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
out = out_flat.view(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
elif self.backend == "vllm":
|
||||
if self._w1 is None or self._w2 is None:
|
||||
self._initialize_weights()
|
||||
|
||||
out_flat = fused_experts_vllm(
|
||||
x,
|
||||
self._w1,
|
||||
self._w2,
|
||||
routing_weights,
|
||||
selected_experts,
|
||||
)
|
||||
out = out_flat.view(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {self.backend}")
|
||||
|
||||
|
||||
class RND1PreTrainedModel(PreTrainedModel):
|
||||
"""Base class for RND1 models with weight initialization and loading support."""
|
||||
config_class = RND1Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["RND1DecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize weights using normal distribution."""
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
*model_args,
|
||||
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
force_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
revision: str = "main",
|
||||
use_safetensors: Optional[bool] = None,
|
||||
weights_only: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Load pretrained model with generation config."""
|
||||
_model = super().from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
*model_args,
|
||||
config=config,
|
||||
cache_dir=cache_dir,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
use_safetensors=use_safetensors,
|
||||
weights_only=weights_only,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
resume_download = kwargs.get("resume_download", None)
|
||||
proxies = kwargs.get("proxies", None)
|
||||
subfolder = kwargs.get("subfolder", "")
|
||||
from_auto_class = kwargs.get("_from_auto", False)
|
||||
from_pipeline = kwargs.get("_from_pipeline", None)
|
||||
|
||||
_model.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
)
|
||||
|
||||
# If configured to use a fused backend, pack fused tensors once after load
|
||||
try:
|
||||
cfg = getattr(_model, "config", None)
|
||||
backend = getattr(cfg, "moe_backend", "hf") if cfg is not None else "hf"
|
||||
if backend in ("sglang", "vllm"):
|
||||
# Walk decoder layers and initialize fused weights
|
||||
model_core = getattr(_model, "model", _model)
|
||||
layers = getattr(model_core, "layers", None)
|
||||
if isinstance(layers, nn.ModuleList):
|
||||
for layer in layers:
|
||||
mlp = getattr(layer, "mlp", None)
|
||||
if hasattr(mlp, "_initialize_weights"):
|
||||
mlp._initialize_weights(
|
||||
free_experts=True,
|
||||
mode=backend,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.warning(f"Backend {backend} weight processing skipped: {_e}")
|
||||
|
||||
return _model
|
||||
|
||||
|
||||
class RND1Model(RND1PreTrainedModel):
|
||||
"""RND1 transformer model with bidirectional attention for diffusion language modeling."""
|
||||
|
||||
def __init__(self, config: RND1Config):
|
||||
super().__init__(config)
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
|
||||
if config.moe_backend == "vllm":
|
||||
RMSNormClass = VLLMRMSNorm
|
||||
else:
|
||||
RMSNormClass = Qwen3MoeRMSNorm
|
||||
self.norm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> MoeModelOutputWithPast:
|
||||
"""Forward pass through the RND1 model."""
|
||||
|
||||
if (input_ids is None) == (inputs_embeds is None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
|
||||
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
# shape: (batch_size, 1, 1, seq_len)
|
||||
attention_mask = attention_mask.to(dtype=torch.float)[:, None, None, :]
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
|
||||
|
||||
position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states, _ = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
router_logits=None,
|
||||
)
|
||||
|
||||
|
||||
class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
|
||||
"""Radical Numerics Diffusion Language Model with bidirectional attention."""
|
||||
|
||||
def __init__(self, config: RND1Config):
|
||||
super().__init__(config)
|
||||
self.model = RND1Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
"""Get the input embeddings layer."""
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
"""Set the input embeddings layer."""
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
"""Get the output embeddings layer (lm_head)."""
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
"""Set the output embeddings layer (lm_head)."""
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@classmethod
|
||||
def can_generate(cls) -> bool:
|
||||
"""Indicates this model can generate text."""
|
||||
return True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> MaskedLMOutput:
|
||||
"""Forward pass with optional loss computation."""
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
logits = self.lm_head(outputs.last_hidden_state)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
)
|
||||
260
dllm/dllm/pipelines/rnd/models/sampling.py
Normal file
@ -0,0 +1,260 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 sampling module for masked diffusion generation.
|
||||
|
||||
This module implements entropy-based token selection for iterative denoising
|
||||
in diffusion language models. Supports both greedy and stochastic sampling
|
||||
with optional prefix/suffix constraints and infilling.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k filtering to logits, setting all values outside the top-k to -inf.
|
||||
"""
|
||||
top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
|
||||
filtered_logits = torch.full_like(logits, float('-inf'))
|
||||
filtered_logits.scatter_(-1, top_k_indices, top_k_values)
|
||||
return filtered_logits
|
||||
|
||||
|
||||
def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-p (nucleus) filtering to logits, setting tokens beyond the cumulative probability threshold to -inf.
|
||||
"""
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above threshold
|
||||
sorted_indices_to_remove = cumulative_probs > p
|
||||
sorted_indices_to_remove[..., 0] = False # Keep at least one token
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
|
||||
return logits.masked_fill(indices_to_remove, float('-inf'))
|
||||
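# Usage sketch (illustrative, not part of this file): when both filters are requested,
# diffusion_sample applies top-k first and then top-p on the surviving logits.
_example_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
_filtered = apply_top_p_filtering(apply_top_k_filtering(_example_logits, k=3), p=0.9)
# torch.softmax(_filtered, dim=-1) puts probability mass only on the tokens kept by both filters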
|
||||
|
||||
@torch.no_grad()
|
||||
def diffusion_sample(
|
||||
model: nn.Module,
|
||||
seq_len: int = 256,
|
||||
num_steps: int = 256,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: float = 1.0,
|
||||
greedy: bool = True,
|
||||
mask_token_id: int = 151669,
|
||||
prefix_ids: Optional[torch.LongTensor] = None,
|
||||
suffix_ids: Optional[torch.LongTensor] = None,
|
||||
infill_length: Optional[int] = None,
|
||||
eos_token_id: int = 151645,
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
visualizer: Optional[object] = None,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Perform masked diffusion sampling with entropy-based token selection.
|
||||
|
||||
Args:
|
||||
model: The RND1 language model
|
||||
seq_len: Target sequence length
|
||||
num_steps: Number of denoising steps
|
||||
top_k: Optional top-k filtering for sampling (None = no filtering)
|
||||
top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
|
||||
When both top_k and top_p are set, top_k is applied first, then top_p
|
||||
temperature: Temperature for sampling (higher = more random, lower = more deterministic)
|
||||
Values close to 0 are clamped to 1e-8 to avoid division by zero
|
||||
greedy: Whether to use greedy sampling (True) or stochastic (False)
|
||||
mask_token_id: Token ID for masked positions (default: 151669)
|
||||
prefix_ids: Optional prefix token IDs to preserve
|
||||
suffix_ids: Optional suffix token IDs to preserve
|
||||
infill_length: Length of infill region between prefix/suffix
|
||||
eos_token_id: End of sequence token ID (default: 151645)
|
||||
pad_token_id: Padding token ID (default: None, uses 0 if needed)
|
||||
bos_token_id: Beginning of sequence token ID (default: None)
|
||||
device: Device for computation (None = infer from model)
|
||||
visualizer: Optional visualizer for live visualization
|
||||
|
||||
Returns:
|
||||
Generated token IDs as LongTensor
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
else:
|
||||
device = torch.device(device)
|
||||
|
||||
if pad_token_id is None:
|
||||
pad_token_id = 0
|
||||
|
||||
# Build initial masked sequence
|
||||
# When prefix_ids is provided, we create a sequence of length seq_len where:
|
||||
# - The prefix occupies the first pre_len positions
|
||||
# - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
|
||||
if prefix_ids is not None or suffix_ids is not None:
|
||||
if prefix_ids is not None:
|
||||
prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
|
||||
pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
|
||||
else:
|
||||
pre_len = 0
|
||||
|
||||
if suffix_ids is not None:
|
||||
suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
|
||||
suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
|
||||
else:
|
||||
suf_len = 0
|
||||
|
||||
reserved = (1 if eos_token_id is not None else 0)
|
||||
used = pre_len + suf_len + reserved
|
||||
|
||||
if used > seq_len:
|
||||
raise ValueError(
|
||||
f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
|
||||
f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
|
||||
f"Please increase seq_len or reduce input lengths."
|
||||
)
|
||||
elif used == seq_len:
|
||||
raise ValueError(
|
||||
f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
|
||||
f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
|
||||
f"Need at least 1 position for generation."
|
||||
)
|
||||
|
||||
infill_length = min(infill_length or (seq_len - used), seq_len - used)
|
||||
|
||||
x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
|
||||
pos = 0
|
||||
# if bos_token_id is not None:
|
||||
# x[0, pos] = bos_token_id; pos += 1
|
||||
if eos_token_id is not None:
|
||||
x[0, -1] = eos_token_id
|
||||
if pre_len > 0:
|
||||
x[0, pos:pos+pre_len] = prefix_ids.flatten()[:pre_len]
|
||||
pos += pre_len
|
||||
fill_start, fill_end = pos, pos + infill_length
|
||||
x[0, fill_start:fill_end] = mask_token_id
|
||||
# print(fill_start, fill_end, seq_len, used, x[0, -1])
|
||||
pos = fill_end
|
||||
if suf_len > 0:
|
||||
x[0, pos:pos+suf_len] = suffix_ids.flatten()[:suf_len]
|
||||
pos += suf_len
|
||||
|
||||
init_maskable = torch.zeros_like(x, dtype=torch.bool)
|
||||
init_maskable[0, fill_start:fill_end] = True
|
||||
else:
|
||||
x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
|
||||
if bos_token_id is not None:
|
||||
x[0, 0] = bos_token_id
|
||||
if eos_token_id is not None:
|
||||
x[0, -1] = eos_token_id
|
||||
init_maskable = x.eq(mask_token_id)
|
||||
|
||||
if bos_token_id is not None:
|
||||
init_maskable[:, 0] = False
|
||||
if eos_token_id is not None:
|
||||
init_maskable &= x.ne(eos_token_id)
|
||||
init_maskable &= x.ne(pad_token_id)
|
||||
|
||||
maskable = init_maskable.clone()
|
||||
xt = x.clone()
|
||||
|
||||
if visualizer:
|
||||
visualizer.start_visualization(xt, maskable, num_steps)
|
||||
|
||||
def forward_scores(tokens):
|
||||
"""Compute predictions and entropy scores for next tokens."""
|
||||
# Try with input_ids parameter first (standard HF models)
|
||||
try:
|
||||
model_output = model(input_ids=tokens)
|
||||
except TypeError:
|
||||
# Fall back to positional argument
|
||||
model_output = model(tokens)
|
||||
|
||||
# Apply temperature scaling (with safety for near-zero temperature)
|
||||
safe_temperature = max(temperature, 1e-8) # Prevent division by zero
|
||||
logits = model_output.logits / safe_temperature
|
||||
|
||||
# Apply filtering strategies
|
||||
# Note: When both top_k and top_p are provided, they are applied sequentially:
|
||||
# First top_k filters to k tokens, then top_p filters from those k tokens
|
||||
if top_k is not None and top_k > 0:
|
||||
logits = apply_top_k_filtering(logits, top_k)
|
||||
|
||||
if top_p is not None and 0 < top_p < 1.0:
|
||||
logits = apply_top_p_filtering(logits, top_p)
|
||||
|
||||
# Convert to log probabilities
|
||||
logp = torch.log_softmax(logits, dim=-1)
|
||||
|
||||
# Greedy or stochastic sampling
|
||||
if greedy:
|
||||
pred_next = logp.argmax(-1)
|
||||
else:
|
||||
pred_next = torch.distributions.Categorical(logits=logp).sample()
|
||||
|
||||
conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
p = logp.exp()
|
||||
ent_next = -(p * logp).sum(-1)
|
||||
|
||||
# Shift predictions: pos i predicts token i+1
|
||||
pred_i = tokens.clone()
|
||||
conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
|
||||
ent_i = torch.zeros_like(ent_next)
|
||||
|
||||
pred_i[:, 1:] = pred_next[:, :-1]
|
||||
conf_i[:, 1:] = conf_next[:, :-1]
|
||||
ent_i[:, 1:] = ent_next[:, :-1]
|
||||
|
||||
return pred_i, conf_i, ent_i
|
||||
|
||||
pred_i, conf_i, ent_i = forward_scores(xt)
|
||||
total_masked = init_maskable.sum(1, keepdim=True)
|
||||
finf = torch.finfo(conf_i.dtype)
|
||||
|
||||
for step in range(num_steps - 1, 0, -1):
|
||||
rate = step / num_steps
|
||||
cutoff_len = (total_masked * rate).long().clamp(min=0)
|
||||
|
||||
# Choose HIGH-entropy tokens to keep masked
|
||||
sel_scores = ent_i.masked_fill(~maskable, -finf.max)
|
||||
B, L = sel_scores.shape
|
||||
k_max = cutoff_len.max().item()
|
||||
if k_max > 0:
|
||||
sss, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
|
||||
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
|
||||
for b in range(B):
|
||||
k_b = int(cutoff_len[b].item())
|
||||
if k_b > 0:
|
||||
keep_mask[b, idx[b, :k_b]] = True
|
||||
else:
|
||||
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
|
||||
|
||||
to_unmask = maskable & ~keep_mask
|
||||
if to_unmask.any():
|
||||
xt[to_unmask] = pred_i[to_unmask]
|
||||
maskable[to_unmask] = False
|
||||
|
||||
if visualizer:
|
||||
visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
|
||||
|
||||
if maskable.any():
|
||||
pred_i, conf_i, ent_i = forward_scores(xt)
|
||||
|
||||
if maskable.any():
|
||||
xt[maskable] = pred_i[maskable]
|
||||
|
||||
if visualizer:
|
||||
visualizer.stop_visualization()
|
||||
|
||||
return xt
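
A minimal usage sketch for the entropy-guided denoising routine above. The function name `generate_rnd` and the checkpoint path are placeholders (the real entry point is defined elsewhere in this patch); keyword arguments follow the docstring at the top of this function.

    import torch
    import dllm  # registers the custom model classes, so no trust_remote_code is needed
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained("path/to/rnd1-checkpoint")  # placeholder path
    model = AutoModel.from_pretrained("path/to/rnd1-checkpoint", dtype=torch.bfloat16)

    prompt_ids = tokenizer("Write a haiku about diffusion.", return_tensors="pt").input_ids[0]
    out = generate_rnd(          # placeholder name for the routine defined above
        model,
        seq_len=256,
        num_steps=64,
        temperature=0.7,
        top_p=0.9,
        greedy=False,
        prefix_ids=prompt_ids,   # prefix is preserved; the rest is denoised from mask tokens
    )
    print(tokenizer.decode(out[0], skip_special_tokens=True))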
|
||||
251
dllm/dllm/pipelines/rnd/models/terminal_visualizer.py
Normal file
@ -0,0 +1,251 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Terminal visualization for RND1 generation.
|
||||
|
||||
This module provides real-time visualization of the diffusion denoising process,
|
||||
showing token evolution and generation progress in the terminal using rich
|
||||
formatting when available.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.text import Text
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||
from rich.layout import Layout
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
|
||||
|
||||
class TerminalVisualizer:
|
||||
"""
|
||||
Rich-based visualization for diffusion process with live updates.
|
||||
|
||||
Provides real-time visualization of the token denoising process during
|
||||
diffusion-based language generation, with colored highlighting of masked
|
||||
positions and progress tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, show_visualization: bool = True):
|
||||
"""
|
||||
Initialize the terminal visualizer.
|
||||
|
||||
Args:
|
||||
tokenizer: The tokenizer for decoding tokens to text
|
||||
show_visualization: Whether to show visualization (requires rich)
|
||||
"""
|
||||
self.tokenizer = tokenizer
|
||||
self.show_visualization = show_visualization and RICH_AVAILABLE
|
||||
if not RICH_AVAILABLE and show_visualization:
|
||||
print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
|
||||
self.show_visualization = False
|
||||
|
||||
if self.show_visualization:
|
||||
self.console = Console()
|
||||
self.live = None
|
||||
self.progress = None
|
||||
self.layout = None
|
||||
else:
|
||||
self.pbar = None
|
||||
|
||||
self.current_tokens = None
|
||||
self.mask_positions = None
|
||||
self.total_steps = 0
|
||||
self.current_step = 0
|
||||
|
||||
def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
|
||||
"""
|
||||
Start the visualization.
|
||||
|
||||
Args:
|
||||
initial_tokens: Initial token IDs (possibly masked)
|
||||
mask_positions: Boolean mask indicating which positions are masked
|
||||
total_steps: Total number of diffusion steps
|
||||
"""
|
||||
if not self.show_visualization:
|
||||
self.pbar = tqdm(total=total_steps, desc="Diffusion")
|
||||
return
|
||||
|
||||
self.current_tokens = initial_tokens.clone()
|
||||
self.mask_positions = mask_positions
|
||||
self.total_steps = total_steps
|
||||
self.current_step = 0
|
||||
|
||||
self.layout = Layout()
|
||||
self.layout.split_column(
|
||||
Layout(name="header", size=3),
|
||||
Layout(name="text", ratio=1),
|
||||
Layout(name="progress", size=3)
|
||||
)
|
||||
|
||||
self.progress = Progress(
|
||||
TextColumn("[bold blue]Diffusion"),
|
||||
BarColumn(),
|
||||
MofNCompleteColumn(),
|
||||
TextColumn("•"),
|
||||
TextColumn("[cyan]Masks: {task.fields[masks]}"),
|
||||
TimeRemainingColumn(),
|
||||
)
|
||||
self.progress_task = self.progress.add_task(
|
||||
"Generating",
|
||||
total=total_steps,
|
||||
masks=mask_positions.sum().item()
|
||||
)
|
||||
|
||||
self.live = Live(self.layout, console=self.console, refresh_per_second=4)
|
||||
self.live.start()
|
||||
self._update_display()
|
||||
|
||||
def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
|
||||
entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
|
||||
"""
|
||||
Update visualization for current step.
|
||||
|
||||
Args:
|
||||
tokens: Current token IDs
|
||||
maskable: Boolean mask of remaining masked positions
|
||||
step: Current step number
|
||||
entropy: Optional entropy scores for each position
|
||||
confidence: Optional confidence scores for each position
|
||||
"""
|
||||
if not self.show_visualization:
|
||||
if self.pbar:
|
||||
self.pbar.update(1)
|
||||
masks = maskable.sum().item() if maskable is not None else 0
|
||||
self.pbar.set_postfix({'masks': masks})
|
||||
return
|
||||
|
||||
self.current_tokens = tokens.clone()
|
||||
self.mask_positions = maskable
|
||||
self.current_step = step
|
||||
|
||||
masks_remaining = maskable.sum().item() if maskable is not None else 0
|
||||
self.progress.update(
|
||||
self.progress_task,
|
||||
advance=1,
|
||||
masks=masks_remaining
|
||||
)
|
||||
|
||||
self._update_display()
|
||||
|
||||
def _update_display(self):
|
||||
"""Update the live display."""
|
||||
if not self.live:
|
||||
return
|
||||
|
||||
header = Text("RND1-Base Generation", style="bold magenta", justify="center")
|
||||
self.layout["header"].update(Panel(header, border_style="bright_blue"))
|
||||
|
||||
text_display = self._format_text_with_masks()
|
||||
self.layout["text"].update(
|
||||
Panel(
|
||||
text_display,
|
||||
title="[bold]Generated Text",
|
||||
subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
|
||||
border_style="cyan"
|
||||
)
|
||||
)
|
||||
|
||||
self.layout["progress"].update(Panel(self.progress))
|
||||
|
||||
def _format_text_with_masks(self) -> Text:
|
||||
"""
|
||||
Format text with colored masks.
|
||||
|
||||
Returns:
|
||||
Rich Text object with formatted tokens
|
||||
"""
|
||||
text = Text()
|
||||
|
||||
if self.current_tokens is None:
|
||||
return text
|
||||
|
||||
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
|
||||
mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
|
||||
# Alternate colors for visual effect
|
||||
text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red")
|
||||
else:
|
||||
try:
|
||||
token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
|
||||
# Skip special tokens in display
|
||||
if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"]:
|
||||
# Color based on position
|
||||
text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
|
||||
except Exception:  # skip tokens that fail to decode
|
||||
continue
|
||||
|
||||
return text
|
||||
|
||||
def stop_visualization(self):
|
||||
"""Stop the visualization and display final result."""
|
||||
if not self.show_visualization:
|
||||
if self.pbar:
|
||||
self.pbar.close()
|
||||
print("\n✨ Generation complete!\n")
|
||||
return
|
||||
|
||||
if self.live:
|
||||
self.live.stop()
|
||||
|
||||
self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
|
||||
|
||||
# Display final text
|
||||
if self.current_tokens is not None:
|
||||
try:
|
||||
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
|
||||
final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
|
||||
|
||||
self.console.print(Panel(
|
||||
final_text,
|
||||
title="[bold]Final Generated Text",
|
||||
border_style="green",
|
||||
padding=(1, 2)
|
||||
))
|
||||
except Exception:  # final decode is best-effort; ignore failures
|
||||
pass
|
||||
|
||||
|
||||
class SimpleProgressBar:
|
||||
"""
|
||||
Simple progress bar fallback when rich is not available.
|
||||
|
||||
Provides basic progress tracking using tqdm when the rich library
|
||||
is not installed.
|
||||
"""
|
||||
|
||||
def __init__(self, total_steps: int):
|
||||
"""
|
||||
Initialize simple progress bar.
|
||||
|
||||
Args:
|
||||
total_steps: Total number of steps
|
||||
"""
|
||||
self.pbar = tqdm(total=total_steps, desc="Diffusion")
|
||||
|
||||
def update(self, masks_remaining: int = 0):
|
||||
"""
|
||||
Update progress bar.
|
||||
|
||||
Args:
|
||||
masks_remaining: Number of masks still remaining
|
||||
"""
|
||||
self.pbar.update(1)
|
||||
self.pbar.set_postfix({'masks': masks_remaining})
|
||||
|
||||
def close(self):
|
||||
"""Close the progress bar."""
|
||||
self.pbar.close()
|
||||
print("\n✨ Generation complete!\n")
|
||||
23
dllm/dllm/pipelines/rnd/trainer.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from dllm.core.trainers import MDLMTrainer
|
||||
|
||||
|
||||
class RNDTrainer(MDLMTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _preprocess_inputs(self, inputs):
|
||||
labels = inputs["labels"]
# Position 0 has no valid prediction after the logit shift below, so its label must already be ignored (-100).
assert (labels[:, 0] == -100).all()
|
||||
|
||||
def _postprocess_outputs(self, outputs):
logits = outputs.logits
# The backbone scores the next token from each position, so shift the logits right by one
# (duplicating position 0) to align each prediction with the position it fills.
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
242
dllm/dllm/tools/chat.py
Normal file
@ -0,0 +1,242 @@
|
||||
import shutil
|
||||
from typing import List, Literal
|
||||
|
||||
import textwrap
|
||||
|
||||
import dllm
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Utility helpers
|
||||
# ============================================================
|
||||
|
||||
try:
|
||||
L = shutil.get_terminal_size().columns
|
||||
if not isinstance(L, int) or L <= 0:
|
||||
L = 120
|
||||
except Exception:
|
||||
L = 120
|
||||
DIV = "=" * L
|
||||
SUB = "-" * L
|
||||
|
||||
|
||||
def banner_line(text: str, width: int = L, fill: str = "=") -> str:
|
||||
"""Return a centered banner line with given width and fill."""
|
||||
text = f" {text.strip()} "
|
||||
fill_len = width - len(text)
|
||||
if fill_len <= 0:
|
||||
return text
|
||||
left = fill_len // 2
|
||||
right = fill_len - left
|
||||
return f"{fill * left}{text}{fill * right}"
|
||||
|
||||
|
||||
def print_wrapped(text: str, width: int = L):
|
||||
"""Print text with automatic line wrapping."""
|
||||
wrapped = textwrap.fill(text, width=width)
|
||||
print(wrapped)
|
||||
|
||||
|
||||
def boxed(text: str, width: int = L, padding: int = 1):
|
||||
"""Render a centered box with the given text and width."""
|
||||
lines = text.splitlines()
|
||||
content_width = max(len(line) for line in lines)
|
||||
box_width = min(width, content_width + padding * 2 + 2)
|
||||
|
||||
# compute left margin for centering
|
||||
terminal_width = width
|
||||
left_margin = max((terminal_width - box_width) // 2, 0)
|
||||
margin = " " * left_margin
|
||||
|
||||
top = margin + "┌" + "─" * (box_width - 2) + "┐"
|
||||
bottom = margin + "└" + "─" * (box_width - 2) + "┘"
|
||||
|
||||
print(top)
|
||||
for line in lines:
|
||||
inner = line.center(content_width)
|
||||
print(margin + "│" + " " * padding + inner + " " * padding + "│")
|
||||
print(bottom)
|
||||
|
||||
|
||||
def decode_trim(tokenizer, seq_ids_list, input_ids_list) -> List[str]:
"""
Return only the generated text for each sequence, truncated at the first EOS **after** its prompt.

Args:
tokenizer: HF tokenizer with eos_token_id / pad_token_id.
seq_ids_list: Per-sample full sequence token ids from the model (prompt + generation).
input_ids_list: Per-sample prompt token ids that were fed into the model.
|
||||
|
||||
Behavior:
|
||||
- Finds the first eos_token_id that occurs at or after len(input_ids).
|
||||
- Slices generation up to (but not including) that EOS.
|
||||
- Decodes only the generation span, skipping special/pad tokens.
|
||||
"""
|
||||
# Work on plain Python lists so the slicing and indexing below are straightforward
|
||||
sequences = []
|
||||
for seq_ids, input_ids in zip(seq_ids_list, input_ids_list):
|
||||
full = list(seq_ids)
|
||||
prompt = list(input_ids)
|
||||
|
||||
# Skip left padding tokens (necessary for dream)
|
||||
pad_id = getattr(tokenizer, "pad_token_id", None)
|
||||
if pad_id is not None:
|
||||
while full and full[0] == pad_id:
|
||||
full.pop(0)
|
||||
|
||||
start = len(prompt)
|
||||
end = len(full)
|
||||
|
||||
eos_id = getattr(tokenizer, "eos_token_id", None)
|
||||
eot_id = getattr(tokenizer, "eot_token_id", None)
|
||||
if eos_id is not None:
|
||||
for i in range(start, len(full)):
|
||||
if full[i] in (eos_id, eot_id):
|
||||
end = i
|
||||
break
|
||||
|
||||
gen_ids = full[start:end]
|
||||
text = tokenizer.decode(gen_ids, skip_special_tokens=True)
|
||||
# If eos_id/eot_id are unavailable, also trim on the raw eos/eot token strings
|
||||
eos = getattr(tokenizer, "eos_token", None)
|
||||
eot = getattr(tokenizer, "eot_token", None)
|
||||
if eos:
|
||||
text = text.split(eos)[0]
|
||||
if eot:
|
||||
text = text.split(eot)[0]
|
||||
# return text.strip()
|
||||
sequences.append(text)
|
||||
return sequences
|
||||
|
||||
|
||||
def render_menu(round_idx: int):
|
||||
"""Render a boxed menu of possible actions."""
|
||||
if round_idx == 0:
|
||||
text = (
|
||||
"Possible next actions:\n"
|
||||
"[1] Continue this chat\n"
|
||||
"[2] End this chat and start a new one\n"
|
||||
"[3] Exit"
|
||||
)
|
||||
else:
|
||||
text = (
|
||||
f"(Round {round_idx})\n"
|
||||
"Possible next actions:\n"
|
||||
"[1] Continue this chat\n"
|
||||
"[2] End this chat and start a new one\n"
|
||||
"[3] Exit"
|
||||
)
|
||||
|
||||
print() # spacing
|
||||
boxed(text)
|
||||
|
||||
|
||||
def prompt_choice() -> Literal["1", "2", "3"]:
|
||||
while True:
|
||||
print("Select action [1/2/3]: ")
|
||||
choice = input().strip()
|
||||
if choice in ("1", "2", "3"):
|
||||
return choice
|
||||
print(banner_line("<Invalid choice. Please type 1, 2, or 3.>", fill=" "))
|
||||
|
||||
|
||||
def build_chat_inputs(tokenizer, messages: List[dict], add_generation_prompt: bool):
|
||||
"""Tokenize chat messages into inputs tensor."""
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
|
||||
def visualize_histories(tokenizer, histories):
|
||||
try:
|
||||
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
terminal_visualizer.visualize(histories, rich=True)
|
||||
except Exception as e:
|
||||
print(f"(Visualization skipped: {e})")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Modes
|
||||
# ============================================================
|
||||
def single_turn_generate(generator, gen_config, visualize: bool):
|
||||
print()
|
||||
print(banner_line("continuation mode"))
|
||||
model, tokenizer = generator.model, generator.tokenizer
|
||||
|
||||
while True:
|
||||
print(banner_line("<Type your prompt below. Press Ctrl+C to exit.>", fill=" "))
|
||||
try:
|
||||
# user_text = input("Prompt > ").strip()
|
||||
print("[Prompt] > ")
|
||||
user_text = input().strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n" + banner_line("Exiting. Bye!", width=len(DIV)))
|
||||
return
|
||||
|
||||
# if not user_text:
|
||||
# print("(Empty input, skipped)\n")
|
||||
# continue
|
||||
|
||||
inputs = tokenizer([user_text], add_special_tokens=False)["input_ids"]
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
text = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
|
||||
|
||||
print(banner_line("Output"))
|
||||
print_wrapped(text if text else "<empty>")
|
||||
print(DIV + "\n")
|
||||
|
||||
if visualize:
|
||||
visualize_histories(tokenizer, outputs.histories)
|
||||
|
||||
|
||||
def multi_turn_chat(generator, gen_config, visualize: bool):
|
||||
# """Chat mode with chat template & message history."""
|
||||
print()
|
||||
print(banner_line("multi-turn chat mode"))
|
||||
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
|
||||
model, tokenizer = generator.model, generator.tokenizer
|
||||
|
||||
messages: List[dict] = []
|
||||
round_idx = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
print("[You]:")
|
||||
user_msg = input().strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nExiting. Bye!")
|
||||
return
|
||||
|
||||
messages.append({"role": "user", "content": user_msg})
|
||||
inputs = build_chat_inputs(tokenizer, [messages], add_generation_prompt=True)
|
||||
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
reply = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
|
||||
|
||||
print(DIV)
|
||||
print_wrapped("[Assistant]: " + reply if reply else "<empty>")
|
||||
print(DIV + "\n")
|
||||
|
||||
messages.append({"role": "assistant", "content": reply})
|
||||
|
||||
if visualize:
|
||||
visualize_histories(tokenizer, outputs.histories)
|
||||
|
||||
render_menu(round_idx)
|
||||
choice = prompt_choice()
|
||||
if choice == "1":
|
||||
print(banner_line("<Type your message.>", fill=" "))
|
||||
round_idx += 1
|
||||
continue
|
||||
elif choice == "2":
|
||||
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
|
||||
messages = []
|
||||
round_idx = 0
|
||||
continue
|
||||
else:
|
||||
print("\nExiting. Bye!")
|
||||
return
|
||||
30
dllm/dllm/tools/download_hf_dataset.py
Normal file
@ -0,0 +1,30 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import tyro
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
dataset_id: str = "Anthropic/hh-rlhf"
|
||||
allow_patterns: str = None
|
||||
|
||||
|
||||
script_args = tyro.cli(ScriptArguments)
|
||||
|
||||
# Replace with the dataset repo you want, e.g. "wikitext"
|
||||
dataset_id = script_args.dataset_id
|
||||
|
||||
# Replace with your desired local directory
|
||||
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/datasets/huggingface/{dataset_id}"
|
||||
|
||||
# Download the dataset snapshot
|
||||
snapshot_download(
|
||||
repo_id=dataset_id,
|
||||
repo_type="dataset", # 👈 tell HF it's a dataset
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=False, # ensures real files, not symlinks
|
||||
allow_patterns=script_args.allow_patterns,
|
||||
)
|
||||
|
||||
print(f"Dataset downloaded to: {local_dir}")
|
||||
27
dllm/dllm/tools/download_hf_model.py
Normal file
@ -0,0 +1,27 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import tyro
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_id: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
|
||||
|
||||
script_args = tyro.cli(ScriptArguments)
|
||||
|
||||
# Replace with the model repo you want, e.g. "bert-base-uncased"
|
||||
model_id = script_args.model_id
|
||||
|
||||
# Replace with your desired local directory
|
||||
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/models/huggingface/{model_id}"
|
||||
|
||||
# Download the model snapshot
|
||||
snapshot_download(
|
||||
repo_id=model_id,
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=False, # ensures real files, not symlinks
|
||||
)
|
||||
|
||||
print(f"Model downloaded to: {local_dir}")
|
||||
1
dllm/dllm/tools/generate.py
Normal file
@ -0,0 +1 @@
|
||||
# TODO
|
||||
80
dllm/dllm/tools/merge_peft_adapter.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
Merge a PEFT/LoRA adapter into its base model (auto-detected from adapter_config.json).
|
||||
|
||||
Usage:
|
||||
python dllm/tools/merge_peft_adapter.py \
|
||||
--adapter_model_name_or_path your-org/your-lora \
|
||||
--output_model_name_or_path ./merged-model \
|
||||
--dtype bf16
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModel, AutoTokenizer, HfArgumentParser
|
||||
|
||||
import dllm # so that no need to trust_remote_code
|
||||
|
||||
DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
adapter_model_name_or_path: str | None = field(
|
||||
default=None, metadata={"help": "Adapter repo or local path"}
|
||||
)
|
||||
output_model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to save the merged model (folder or repo id)"},
|
||||
)
|
||||
dtype: str | None = field(default="fp16", metadata={"help": "fp16|bf16|fp32"})
|
||||
push_to_hub: bool | None = field(
|
||||
default=False, metadata={"help": "Push merged weights to the Hub"}
|
||||
)
|
||||
# Optional override if adapter config lacks base info:
|
||||
base_model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Override base model if adapter config lacks it"},
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
assert args.adapter_model_name_or_path, "please provide the adapter (repo or path)"
|
||||
assert args.output_model_name_or_path, "please provide output_model_name_or_path"
|
||||
assert args.dtype in DTYPE_MAP, f"dtype must be one of {list(DTYPE_MAP.keys())}"
|
||||
|
||||
# Read base path from adapter_config.json
|
||||
peft_cfg = PeftConfig.from_pretrained(args.adapter_model_name_or_path)
|
||||
base_id = args.base_model_name_or_path or getattr(
|
||||
peft_cfg, "base_model_name_or_path", None
|
||||
)
|
||||
assert base_id, (
|
||||
"adapter_config.json does not include base_model_name_or_path; "
|
||||
"pass --base_model_name_or_path to override."
|
||||
)
|
||||
|
||||
# Load base model and tokenizer
|
||||
model = AutoModel.from_pretrained(
|
||||
base_id, return_dict=True, dtype=DTYPE_MAP[args.dtype]
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_id)
|
||||
|
||||
# Attach adapter, merge, and unload PEFT layers
|
||||
model = PeftModel.from_pretrained(model, args.adapter_model_name_or_path)
|
||||
model.eval()
|
||||
model = model.merge_and_unload() # plain transformers model
|
||||
|
||||
# Save locally
|
||||
model.save_pretrained(args.output_model_name_or_path)
|
||||
tokenizer.save_pretrained(args.output_model_name_or_path)
|
||||
|
||||
print(f"✓ merged model saved to: {args.output_model_name_or_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
109
dllm/dllm/tools/preprocess_pt_dataset.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""
|
||||
python dllm/tools/preprocess_pt_dataset.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, asdict
|
||||
from functools import partial
|
||||
|
||||
import datasets
|
||||
import tyro
|
||||
import transformers
|
||||
from pprint import pprint
|
||||
|
||||
import dllm
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""Preprocess PT dataset"""
|
||||
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
dataset_args: str = "OpenCoder-LLM/opc-annealing-corpus[lang:python]" # required
|
||||
output_dir: str = "data/pt/modernbert/opc-annealing-corpus[lang:python]" # required
|
||||
text_field: str = "text"
|
||||
max_length: int = 1024
|
||||
insert_eos: bool = True
|
||||
drop_tail: bool = True
|
||||
remove_columns: bool = False
|
||||
num_proc: int = 32
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
def preprocess_pt_dataset(
|
||||
dataset: datasets.DatasetDict,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
output_dir: str,
|
||||
text_field: str = "text",
|
||||
max_length: int = 1024,
|
||||
insert_eos: bool = True,
|
||||
drop_tail: bool = True,
|
||||
remove_columns: bool = False,
|
||||
num_proc: int = 32,
|
||||
):
|
||||
processed = dataset.map(
|
||||
partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=text_field,
|
||||
seq_length=max_length,
|
||||
insert_eos=insert_eos,
|
||||
drop_tail=drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
)
|
||||
|
||||
# Keep only the required columns (input_ids, labels) to save space.
|
||||
if remove_columns:
|
||||
keep = {"input_ids", "labels"}
|
||||
|
||||
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
|
||||
drop = [c for c in ds.column_names if c not in keep]
|
||||
return ds.remove_columns(drop) if drop else ds
|
||||
|
||||
if isinstance(processed, datasets.DatasetDict):
|
||||
for split in list(processed.keys()):
|
||||
processed[split] = strip_cols(processed[split])
|
||||
else:
|
||||
processed = strip_cols(processed)
|
||||
|
||||
output_dir = os.path.join(
|
||||
output_dir,
|
||||
f"max_length-{max_length}-insert_eos-{insert_eos}-drop_tail-{drop_tail}",
|
||||
)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
processed.save_to_disk(output_dir)
|
||||
print(f"[OK] Saved to: {output_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse with tyro
|
||||
args = tyro.cli(ScriptArguments)
|
||||
dllm.utils.print_args(args)
|
||||
|
||||
tokenizer = dllm.utils.get_tokenizer(args)
|
||||
|
||||
# Load your raw dataset (must contain the text field given by --text_field).
|
||||
dataset = dllm.data.load_pt_dataset(args.dataset_args, streaming=False)
|
||||
|
||||
preprocess_pt_dataset(
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
output_dir=args.output_dir,
|
||||
text_field=args.text_field,
|
||||
max_length=args.max_length,
|
||||
insert_eos=args.insert_eos,
|
||||
drop_tail=args.drop_tail,
|
||||
remove_columns=args.remove_columns,
|
||||
num_proc=args.num_proc,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
117
dllm/dllm/tools/preprocess_sft_dataset.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""
|
||||
Example:
|
||||
|
||||
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
|
||||
--num_proc 64
|
||||
"""
|
||||
|
||||
import os
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
|
||||
import datasets
|
||||
import tyro
|
||||
|
||||
import dllm
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""Preprocess SFT dataset"""
|
||||
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
|
||||
sft_map_fn_path: str = "dllm.utils.default_sft_map_fn"
|
||||
dataset_args: str = "HuggingFaceTB/smoltalk" # required
|
||||
output_dir: str = "data/sft/llada/smoltalk" # required
|
||||
mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100
|
||||
num_proc: int = 32
|
||||
remove_columns: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
def preprocess_sft_dataset(
|
||||
dataset: datasets.DatasetDict,
|
||||
map_fn: callable,
|
||||
output_dir: str,
|
||||
remove_columns: bool = False,
|
||||
num_proc: int = 32,
|
||||
):
|
||||
processed = dataset.map(
|
||||
map_fn,
|
||||
batched=False,
|
||||
num_proc=num_proc,
|
||||
load_from_cache_file=True,
|
||||
writer_batch_size=512,
|
||||
desc="offline preprocessing",
|
||||
)
|
||||
|
||||
# Keep only the required columns (input_ids, labels, prompt_len, attention_mask) to save space.
|
||||
if remove_columns:
|
||||
keep = {"input_ids", "labels", "prompt_len", "attention_mask"}
|
||||
|
||||
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
|
||||
drop = [c for c in ds.column_names if c not in keep]
|
||||
return ds.remove_columns(drop) if drop else ds
|
||||
|
||||
if isinstance(processed, datasets.DatasetDict):
|
||||
for split in list(processed.keys()):
|
||||
processed[split] = strip_cols(processed[split])
|
||||
else:
|
||||
processed = strip_cols(processed)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
processed.save_to_disk(output_dir)
|
||||
print(f"[OK] Saved to: {output_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse with tyro
|
||||
args = tyro.cli(ScriptArguments)
|
||||
dllm.utils.print_args(args)
|
||||
|
||||
tokenizer = dllm.utils.get_tokenizer(args)
|
||||
|
||||
# Load your raw dataset (must contain a "messages" field per example).
|
||||
dataset = dllm.data.load_sft_dataset(args.dataset_args)
|
||||
|
||||
# Dynamically import the SFT map function specified by --sft_map_fn_path
|
||||
try:
|
||||
# Split the path into module and function name
|
||||
module_path, function_name = args.sft_map_fn_path.rsplit(".", 1)
|
||||
|
||||
# Import the module
|
||||
module = importlib.import_module(module_path)
|
||||
|
||||
# Get the function from the module
|
||||
sft_map_fn = getattr(module, function_name)
|
||||
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
print(f"Error: Could not import '{args.sft_map_fn_path}'.")
|
||||
print(f"Details: {e}")
|
||||
return
|
||||
|
||||
map_fn = partial(
|
||||
sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=args.mask_prompt_loss,
|
||||
)
|
||||
preprocess_sft_dataset(
|
||||
dataset=dataset,
|
||||
map_fn=map_fn,
|
||||
output_dir=args.output_dir,
|
||||
remove_columns=args.remove_columns,
|
||||
num_proc=args.num_proc,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
dllm/dllm/utils/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from . import configs, generation_utils, model_utils, utils
|
||||
from .configs import *
|
||||
from .generation_utils import *
|
||||
from .data_utils import *
|
||||
from .model_utils import *
|
||||
from .utils import *
|
||||
77
dllm/dllm/utils/configs.py
Normal file
@ -0,0 +1,77 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
|
||||
from dllm.utils.utils import resolve_with_base_env, get_default_logger
|
||||
|
||||
logger = get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: str = None # overwrite this
|
||||
dtype: str = "bfloat16"
|
||||
load_in_4bit: bool = False
|
||||
attn_implementation: str = None
|
||||
# --- fold PEFT args here ---
|
||||
lora: bool = False
|
||||
target_modules: str = "all-linear"
|
||||
r: int = 32
|
||||
lora_alpha: int = 64
|
||||
lora_dropout: float = 0.05
|
||||
bias: str = "none"
|
||||
modules_to_save: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
dataset_args: str = None # overwrite this
|
||||
num_proc: int = 8
|
||||
disable_caching: bool = False
|
||||
max_length: int = 1024
|
||||
truncation: str = field(
|
||||
default="right",
|
||||
metadata={
|
||||
"help": (
|
||||
'The truncation strategy to use ("filter" or "right"). '
|
||||
'"filter" only keeps sequences that are shorter than max_length; '
|
||||
'"right" only keeps the rightmost max_length tokens for each sequence.'
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(transformers.TrainingArguments):
|
||||
output_dir: str = None # overwrite this
|
||||
report_to: str = "wandb"
|
||||
overwrite_output_dir: bool = True
|
||||
seed: int = 42
|
||||
per_device_train_batch_size: int = 4
|
||||
per_device_eval_batch_size: int = 4
|
||||
gradient_accumulation_steps: int = 1
|
||||
learning_rate: float = 2e-5
|
||||
lr_scheduler_type: str = "cosine"
|
||||
warmup_ratio: float = 0.1
|
||||
bf16: bool = True
|
||||
num_train_epochs: float = 4
|
||||
logging_steps: float = 10
|
||||
eval_on_start: bool = False
|
||||
eval_strategy: str = "steps"
|
||||
eval_steps: float = 0.25
|
||||
save_steps: float = 0.25
|
||||
save_only_model: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.group_by_length:
|
||||
logger.info(
|
||||
"training_args.group_by_length=True: preprocessing "
|
||||
"may take some time after `trainer.train()` starts."
|
||||
)
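
A hedged sketch of how these dataclasses are typically consumed; the concrete training scripts live elsewhere in this patch, and the HfArgumentParser usage here is an assumption modeled on merge_peft_adapter.py.

    import transformers
    from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments

    # run with e.g. --model_name_or_path ... --dataset_args ... --output_dir ...
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()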
|
||||
222
dllm/dllm/utils/data_utils.py
Normal file
@ -0,0 +1,222 @@
|
||||
import random
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import datasets
|
||||
import transformers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
|
||||
|
||||
|
||||
def tokenize_and_group(
|
||||
examples,
|
||||
tokenizer,
|
||||
text_field: str = "text",
|
||||
seq_length: int = 1024,
|
||||
insert_eos: bool = False,
|
||||
drop_tail: bool = True,
|
||||
add_special_tokens: bool = False,
|
||||
):
|
||||
# 1) Tokenize (batched input)
|
||||
tokenized = tokenizer(examples[text_field], add_special_tokens=add_special_tokens)
|
||||
ids = tokenized["input_ids"]
|
||||
|
||||
# --- optionally append EOS to each sample ---
|
||||
if insert_eos:
|
||||
eos_id = tokenizer.eos_token_id
assert eos_id is not None, "insert_eos=True requires the tokenizer to define eos_token_id"
|
||||
# append EOS only if the sample doesn't already end with it
|
||||
ids = [seq + ([] if (seq and seq[-1] == eos_id) else [eos_id]) for seq in ids]
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
# 2) Flatten and concatenate all token lists
|
||||
concatenated = list(chain.from_iterable(ids))
|
||||
if not concatenated:
|
||||
return {"input_ids": [], "labels": []} # Safe return for empty batch
|
||||
|
||||
# 3) Calculate the total length based on drop_tail
|
||||
if drop_tail:
|
||||
total_len = (len(concatenated) // seq_length) * seq_length
|
||||
concatenated = concatenated[:total_len] # Truncate the last incomplete chunk
|
||||
else:
|
||||
total_len = len(concatenated)
|
||||
|
||||
# Split into fixed-length chunks
|
||||
chunks = [concatenated[i : i + seq_length] for i in range(0, total_len, seq_length)]
|
||||
|
||||
return {
|
||||
"input_ids": chunks,
|
||||
"labels": [c[:] for c in chunks], # Labels are the same as input_ids
|
||||
}
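
A toy illustration of the packing behavior above; `WhitespaceTokenizer` is a stand-in for a HF tokenizer and is not part of dllm.

    from dllm.utils.data_utils import tokenize_and_group  # the function defined above

    class WhitespaceTokenizer:
        eos_token_id = 99
        def __call__(self, texts, add_special_tokens=False):
            vocab = {}
            ids = [[vocab.setdefault(w, len(vocab) + 1) for w in t.split()] for t in texts]
            return {"input_ids": ids}

    examples = {"text": ["a b c d e", "f g h"]}
    out = tokenize_and_group(examples, WhitespaceTokenizer(), seq_length=4,
                             insert_eos=True, drop_tail=True)
    # Samples become [1..5, 99] and [6, 7, 8, 99]; after concatenation and re-chunking into
    # length-4 blocks (dropping the incomplete tail) we get:
    print(out["input_ids"])  # [[1, 2, 3, 4], [5, 99, 6, 7]]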
|
||||
|
||||
|
||||
def clip_row(row: dict, max_length: int, truncation: str = "right") -> dict:
|
||||
for key in ("input_ids", "labels", "attention_mask"):
|
||||
if key in row:
|
||||
if truncation == "right":
|
||||
row[key] = row[key][:max_length]
|
||||
elif truncation == "left":
|
||||
row[key] = row[key][-max_length:]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return row
|
||||
|
||||
|
||||
def post_process_dataset(
|
||||
dataset: datasets.DatasetDict, data_args: "DataArguments"
|
||||
) -> datasets.DatasetDict:
|
||||
if data_args.truncation == "filter":
|
||||
return dataset.filter(
|
||||
lambda row: len(row["input_ids"]) <= data_args.max_length,
|
||||
num_proc=data_args.num_proc,
|
||||
desc=f"Filtering samples with length <= {data_args.max_length}",
|
||||
)
|
||||
elif data_args.truncation == "right":
|
||||
# do this only if dataset has "prompt_len"
|
||||
if "prompt_len" in dataset.column_names["train"]:
|
||||
dataset = dataset.filter(
|
||||
lambda row: row["prompt_len"] <= data_args.max_length,
|
||||
num_proc=data_args.num_proc,
|
||||
desc=f"Filtering samples with `prompt_len` <= {data_args.max_length}",
|
||||
)
|
||||
return dataset.map(
|
||||
lambda row: clip_row(row, data_args.max_length, truncation="right"),
|
||||
num_proc=data_args.num_proc,
|
||||
desc=f"Right-truncating samples to max_length={data_args.max_length}",
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def clip_row_streaming(row: dict, max_length: int, truncation: str = "right") -> dict:
|
||||
"""Clip whole sequence OR (if prompt_len present) preserve prompt and clip only the response."""
|
||||
if truncation not in {"right", "left"}:
|
||||
raise NotImplementedError(f"Unknown truncation: {truncation}")
|
||||
|
||||
def clip(seq):
|
||||
return seq[:max_length] if truncation == "right" else seq[-max_length:]
|
||||
|
||||
def clip_preserve_prompt(seq, prompt_len: int):
|
||||
prompt = seq[:prompt_len]
|
||||
resp = seq[prompt_len:]
|
||||
budget = max(0, max_length - len(prompt))
|
||||
resp = resp[:budget] if truncation == "right" else resp[-budget:]
|
||||
return prompt + resp
|
||||
|
||||
prompt_len = row.get("prompt_len", None)
|
||||
for k in ("input_ids", "labels", "attention_mask"):
|
||||
if k in row and isinstance(row[k], list):
|
||||
row[k] = (
|
||||
clip_preserve_prompt(row[k], prompt_len)
|
||||
if isinstance(prompt_len, int) and prompt_len >= 0
|
||||
else clip(row[k])
|
||||
)
|
||||
return row
|
||||
|
||||
|
||||
def post_process_dataset_streaming(
|
||||
dataset: datasets.IterableDatasetDict,
|
||||
data_args: "DataArguments",
|
||||
) -> datasets.IterableDatasetDict:
|
||||
|
||||
def _train_has_prompt_len_streaming(dataset: datasets.IterableDatasetDict) -> bool:
|
||||
"""Replicates: 'if \"prompt_len\" in dataset.column_names[\"train\"]' for streaming."""
|
||||
it = dataset["train"].take(1)
|
||||
try:
|
||||
ex = next(iter(it))
|
||||
except StopIteration:
|
||||
return False
|
||||
return "prompt_len" in ex
|
||||
|
||||
mode = data_args.truncation
|
||||
max_len = data_args.max_length
|
||||
|
||||
if mode == "filter":
|
||||
# Keep rows with len(input_ids) <= max_len (emulate .filter with generator map)
|
||||
def keep_if_short(row):
|
||||
if (
|
||||
"input_ids" in row
|
||||
and isinstance(row["input_ids"], list)
|
||||
and len(row["input_ids"]) <= max_len
|
||||
):
|
||||
yield row # keep
|
||||
# else: drop (yield nothing)
|
||||
|
||||
return datasets.IterableDatasetDict(
|
||||
{name: ds.map(keep_if_short) for name, ds in dataset.items()}
|
||||
)
|
||||
|
||||
elif mode == "right":
|
||||
ds_out = dataset
|
||||
|
||||
# Do this only if TRAIN split has "prompt_len" (same condition as your non-streaming code)
|
||||
if _train_has_prompt_len_streaming(ds_out):
|
||||
|
||||
def keep_if_prompt_fits(row):
|
||||
pl = row.get("prompt_len", None)
|
||||
if isinstance(pl, int) and pl <= max_len:
|
||||
yield row # keep
|
||||
elif pl is None:
|
||||
# If a row lacks prompt_len but train had it, the non-streaming code would try to access it and fail.
|
||||
# Here we conservatively drop such rows to mirror "requires prompt_len <= max_len".
|
||||
return
|
||||
# else: drop
|
||||
|
||||
ds_out = datasets.IterableDatasetDict(
|
||||
{name: ds.map(keep_if_prompt_fits) for name, ds in ds_out.items()}
|
||||
)
|
||||
|
||||
# Then clip right (same clipping as clip_row)
|
||||
def clip_right(row):
|
||||
return clip_row(row, max_len, truncation="right")
|
||||
|
||||
return datasets.IterableDatasetDict(
|
||||
{name: ds.map(clip_right) for name, ds in ds_out.items()}
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class NoAttentionMaskCollator(transformers.DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors)
|
||||
# finetune on padded <eos_token> positions; they should not be masked out
|
||||
outputs.pop("attention_mask")
|
||||
return outputs
|
||||
|
||||
|
||||
def default_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
|
||||
"""
|
||||
Build input_ids and labels for SFT.
|
||||
|
||||
Args:
|
||||
row: a dataset row with `messages`
|
||||
tokenizer: a HF tokenizer
|
||||
mask_prompt_loss: whether to mask prompt tokens (set their labels to -100)
|
||||
|
||||
Returns:
|
||||
dict with keys: input_ids, labels, and optionally prompt_len
|
||||
"""
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"], tokenize=True, add_generation_prompt=False
|
||||
)
|
||||
labels = prompt_response_tokens.copy()
|
||||
|
||||
if mask_prompt_loss:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"labels": labels,
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
|
||||
return {"input_ids": prompt_response_tokens, "labels": labels}
|
||||
53
dllm/dllm/utils/generation_utils.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
|
||||
from dllm.core.schedulers import BaseAlphaScheduler
|
||||
|
||||
|
||||
def get_num_transfer_tokens(
|
||||
mask_index: torch.Tensor,
|
||||
steps: int,
|
||||
scheduler: BaseAlphaScheduler,
|
||||
stochastic: bool = False,
|
||||
) -> torch.Tensor:
|
||||
mask_num = mask_index.sum(dim=1, keepdim=True)
|
||||
num_transfer_tokens = torch.zeros(
|
||||
mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64
|
||||
)
|
||||
for i in range(mask_num.size(0)):
|
||||
for t, s, j in zip(range(steps, 0, -1), range(steps - 1, -1, -1), range(steps)):
|
||||
s /= steps
|
||||
t /= steps
|
||||
reverse_transfer_prob = 1 - scheduler.reverse_mask_prob(s=s, t=t)
|
||||
if not stochastic:
|
||||
x = mask_num[i, 0].to(torch.float64) * reverse_transfer_prob
|
||||
num_transfer_tokens[i, j] = torch.round(x).to(torch.int64)
|
||||
else:
|
||||
n = mask_num[i, 0].to(torch.float64)
|
||||
num_transfer_tokens[i, j] = (
|
||||
torch.distributions.Binomial(n, reverse_transfer_prob)
|
||||
.sample()
|
||||
.to(torch.int64)
|
||||
)
|
||||
num_transfer_tokens[i, j] = torch.minimum(
|
||||
num_transfer_tokens[i, j], mask_num[i, 0]
|
||||
)
|
||||
mask_num[i, 0] -= num_transfer_tokens[i, j]
|
||||
if mask_num[i, 0].item() == 0:
|
||||
break
|
||||
# Note: because LLaDA is not conditioned on time, steps with no unmasking (i.e. no transfer) can be skipped.
# Compact each row by removing zeros, then right-pad with zeros up to the max length across rows.
|
||||
rows = []
|
||||
max_len = 0
|
||||
for i in range(num_transfer_tokens.size(0)):
|
||||
nonzero = num_transfer_tokens[i][num_transfer_tokens[i] > 0]
|
||||
rows.append(nonzero)
|
||||
max_len = max(max_len, nonzero.numel())
|
||||
# Pad each row to max_len
|
||||
padded_rows = []
|
||||
for r in rows:
|
||||
if r.numel() < max_len:
|
||||
pad = torch.zeros(max_len - r.numel(), dtype=r.dtype, device=r.device)
|
||||
r = torch.cat([r, pad])
|
||||
padded_rows.append(r)
|
||||
return torch.stack(padded_rows, dim=0)
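
A small worked example of the transfer schedule above. `LinearScheduler` is a toy stand-in for a concrete `BaseAlphaScheduler` with alpha(t) = 1 - t, under which the probability of staying masked from time t to s is s / t; this is an assumption for illustration, not the library's actual scheduler.

    import torch
    from dllm.utils.generation_utils import get_num_transfer_tokens  # defined above

    class LinearScheduler:
        # toy stand-in implementing only the method used above
        def reverse_mask_prob(self, s: float, t: float) -> float:
            return s / t

    mask_index = torch.ones(1, 10, dtype=torch.bool)  # one sequence with 10 masked positions
    plan = get_num_transfer_tokens(mask_index, steps=4, scheduler=LinearScheduler())
    print(plan)  # deterministic split of the 10 masked tokens over 4 steps, e.g. tensor([[2, 3, 2, 3]])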
|
||||
180
dllm/dllm/utils/model_utils.py
Normal file
@ -0,0 +1,180 @@
|
||||
import torch
|
||||
import accelerate
|
||||
import transformers
|
||||
from peft import prepare_model_for_kbit_training
|
||||
|
||||
from dllm.utils.utils import disable_caching_allocator_warmup, print_main, load_peft
|
||||
from dllm.utils.configs import ModelArguments, TrainingArguments
|
||||
|
||||
|
||||
def get_model(
|
||||
model_args,
|
||||
config: transformers.PretrainedConfig | None = None,
|
||||
) -> transformers.PreTrainedModel:
|
||||
"""
|
||||
Load a model with flexible input sources.
|
||||
|
||||
Args:
|
||||
model_args: An optional dataclass or namespace containing model parameters.
|
||||
model_name_or_path: Optional direct model path or name (overrides model_args.model_name_or_path).
|
||||
dtype: Dtype (string or torch.dtype).
|
||||
load_in_4bit: Whether to load using 4-bit quantization (can override model_args.load_in_4bit).
|
||||
|
||||
Returns:
|
||||
transformers.PreTrainedModel
|
||||
"""
|
||||
model_name_or_path = getattr(model_args, "model_name_or_path")
|
||||
dtype = getattr(model_args, "dtype", "bfloat16")
|
||||
load_in_4bit = getattr(model_args, "load_in_4bit", False)
|
||||
attn_implementation = getattr(model_args, "attn_implementation", None)
|
||||
|
||||
# Device map: skip when ZeRO-3
|
||||
device_map = (
|
||||
{"": accelerate.PartialState().local_process_index}
|
||||
if not transformers.modeling_utils.is_deepspeed_zero3_enabled()
|
||||
and torch.cuda.is_available()
|
||||
else None
|
||||
)
|
||||
|
||||
quant_config = None
|
||||
if load_in_4bit and transformers.utils.is_bitsandbytes_available():
|
||||
quant_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
params = {
|
||||
"dtype": dtype,
|
||||
"device_map": device_map,
|
||||
"quantization_config": quant_config,
|
||||
"attn_implementation": attn_implementation,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
try:
|
||||
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_name_or_path, **params
|
||||
)
|
||||
except Exception:  # fall back for architectures not registered with AutoModelForMaskedLM
|
||||
model = transformers.AutoModel.from_pretrained(model_name_or_path, **params)
|
||||
|
||||
# --- if quantized, prepare for LoRA / QLoRA training ---
|
||||
if load_in_4bit and quant_config is not None:
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
|
||||
|
||||
# Optionally train with lora
|
||||
model = load_peft(model, model_args)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
|
||||
"""
|
||||
Load a tokenizer with flexible input sources.
|
||||
|
||||
Args:
|
||||
model_args: Optional dataclass or namespace containing model parameters.
|
||||
model: Optional model instance to configure tokenizer behavior.
|
||||
model_name_or_path: Optional direct model name or path (overrides model_args.model_name_or_path).
|
||||
|
||||
Returns:
|
||||
transformers.PreTrainedTokenizer
|
||||
"""
|
||||
# Lazy imports to avoid circular dependencies
|
||||
from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM
|
||||
from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM
|
||||
from dllm.pipelines.dream.models.modeling_dream import DreamModel
|
||||
from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
|
||||
from transformers import (
|
||||
BertPreTrainedModel,
|
||||
RobertaPreTrainedModel,
|
||||
ModernBertPreTrainedModel,
|
||||
)
|
||||
|
||||
model_name_or_path = getattr(model_args, "model_name_or_path")
|
||||
|
||||
# ---------------- Tokenizer loading ----------------
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
padding_side="right",
|
||||
)
|
||||
|
||||
assert tokenizer.eos_token is not None or tokenizer.pad_token is not None
|
||||
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if not tokenizer.eos_token:
|
||||
tokenizer.eos_token = tokenizer.pad_token
|
||||
if not tokenizer.bos_token:
|
||||
tokenizer.bos_token = tokenizer.pad_token
|
||||
|
||||
# Resolve the concrete model class from the config to select the customization below
|
||||
model_cfg = transformers.AutoConfig.from_pretrained(model_name_or_path)
|
||||
model_cls = transformers.AutoModel._model_mapping[type(model_cfg)]
|
||||
|
||||
# ---------------- Model-specific customization ----------------
|
||||
if issubclass(model_cls, LLaDAModelLM):
|
||||
tokenizer.add_special_tokens({"mask_token": "<|mdm_mask|>"})
|
||||
tokenizer.eot_token = "<|eot_id|>"
|
||||
# tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) # can not do this for llada base directly
|
||||
# TODO: for llada base, add special_tokens = {"<|start_header_id|>": 126346, "<|end_header_id|>": 126347, "<|eot_id|>": 126348}
|
||||
# fix bugs in chat template
|
||||
tokenizer.chat_template = """\
|
||||
{% set loop_messages = messages %}
|
||||
{% for message in loop_messages %}
|
||||
{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}
|
||||
<|start_header_id|>{{ message['role'] }}<|end_header_id|>
|
||||
|
||||
{{ message['content'] | trim }}<|eot_id|>
|
||||
{%- endfor %}
|
||||
{% if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
|
||||
<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{% endif %}
|
||||
"""
|
||||
elif issubclass(model_cls, LLaDAMoEModelLM):
|
||||
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
|
||||
tokenizer.eot_token = "<|role_end|>"
|
||||
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
|
||||
elif issubclass(model_cls, DreamModel):
|
||||
tokenizer.eot_token = "<|im_end|>"
|
||||
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
|
||||
elif issubclass(model_cls, RND1LM):
|
||||
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
|
||||
elif issubclass(
|
||||
model_cls,
|
||||
(BertPreTrainedModel, RobertaPreTrainedModel, ModernBertPreTrainedModel),
|
||||
):
|
||||
tokenizer.eot_token = "[/Answer]"
|
||||
tokenizer.chat_template = """\
|
||||
{% if messages[0]['role'] == 'system' %}
|
||||
[SYS]
|
||||
{{ messages[0]['content'] | trim }}
|
||||
[/SYS]
|
||||
|
||||
{% set loop_messages = messages[1:] %}
|
||||
{% else %}
|
||||
{% set loop_messages = messages %}
|
||||
{% endif -%}
|
||||
{%- for message in loop_messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
[Question]
|
||||
{{ message['content'] | trim }}
|
||||
[/Question]
|
||||
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
[Answer]
|
||||
{{ message['content'] | trim }}
|
||||
[/Answer]
|
||||
|
||||
{% endif %}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
|
||||
[Answer]
|
||||
{% endif %}
|
||||
"""
|
||||
else:
|
||||
print_main("no tokenizer customization for model class:", model_cls)
|
||||
return tokenizer
|
||||
284
dllm/dllm/utils/utils.py
Normal file
@ -0,0 +1,284 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
|
||||
|
||||
import pprint
|
||||
import torch
|
||||
import peft
|
||||
import accelerate
|
||||
import transformers
|
||||
|
||||
|
||||
def resolve_with_base_env(path: str, env_name: str) -> str:
|
||||
"""
|
||||
If `env_name` is set and `path` is NOT absolute, NOT a URL/scheme,
|
||||
and does not already exist locally, prepend the `env_name` directory.
|
||||
|
||||
If the resulting path does not exist, return the base environment directory instead.
|
||||
Otherwise return `path` unchanged.
|
||||
"""
|
||||
base = os.getenv(env_name, "").strip()
|
||||
if not base:
|
||||
return path
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
|
||||
candidate = os.path.join(base.rstrip("/"), path.lstrip("/"))
|
||||
if os.path.exists(candidate):
|
||||
return candidate
|
||||
else:
|
||||
raise FileNotFoundError(f"'{path}' not found locally or under {env_name} ({candidate})")
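
For illustration, assuming BASE_MODELS_DIR points at a local mirror such as /mnt/models (a hypothetical path):

    import os
    os.environ["BASE_MODELS_DIR"] = "/mnt/models"  # hypothetical base directory
    path = resolve_with_base_env("GSAI-ML/LLaDA-8B-Base", "BASE_MODELS_DIR")
    # -> "GSAI-ML/LLaDA-8B-Base" if it exists relative to the CWD,
    # -> "/mnt/models/GSAI-ML/LLaDA-8B-Base" if only the mirrored copy exists,
    # -> raises FileNotFoundError otherwise.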
|
||||
|
||||
|
||||
@contextmanager
|
||||
def init_device_context_manager(device: str | torch.device | None = None):
|
||||
"""
|
||||
Temporarily set torch default dtype and default device so that tensors
|
||||
created inside the context are allocated on `device` with dtype `dtype`.
|
||||
Restores previous settings on exit.
|
||||
"""
|
||||
if transformers.integrations.is_deepspeed_zero3_enabled():
|
||||
yield
|
||||
return
|
||||
|
||||
# Resolve device
|
||||
if device is None:
|
||||
try:
|
||||
from accelerate import PartialState
|
||||
|
||||
idx = PartialState().local_process_index
|
||||
except Exception:
|
||||
idx = 0
|
||||
device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
|
||||
elif isinstance(device, int):
|
||||
device = f"cuda:{device}"
|
||||
|
||||
try:
|
||||
torch.set_default_device(device)
|
||||
yield
|
||||
finally:
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
|
||||
def print_main(*args, **kwargs):
|
||||
"""
|
||||
Print only from the global main process (rank 0 across all nodes).
|
||||
Usage: print_main("Hello from main process!")
|
||||
"""
|
||||
if accelerate.PartialState().is_main_process:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def pprint_main(*args, **kwargs):
|
||||
"""
|
||||
Print (with pprint) only from the global main process (rank 0 across all nodes).
|
||||
Usage: pprint_main({"hello": "from main process"})
|
||||
"""
|
||||
if accelerate.PartialState().is_main_process:
|
||||
pprint.pprint(*args, **kwargs)
|
||||
|
||||
|
||||
def load_peft(
|
||||
model: transformers.PreTrainedModel, model_args: "ModelArguments"
|
||||
) -> transformers.PreTrainedModel:
|
||||
"""
|
||||
e.g.,
|
||||
--modules_to_save "lm_head" --target_modules "q_proj,k_proj,v_proj,o_proj,up_proj,down_proj,gate_proj"
|
||||
--target_modules "all-linear"
|
||||
"""
|
||||
if not getattr(model_args, "lora", False):
|
||||
return model
|
||||
target_modules = (
|
||||
model_args.target_modules.split(",") if model_args.target_modules else None
|
||||
)
|
||||
# if it’s a single 'all-linear', drop the list and use the string directly
|
||||
if (
|
||||
target_modules
|
||||
and len(target_modules) == 1
|
||||
and target_modules[0].strip() == "all-linear"
|
||||
):
|
||||
target_modules = target_modules[0]
|
||||
modules_to_save = (
|
||||
model_args.modules_to_save.split(",") if model_args.modules_to_save else None
|
||||
)
|
||||
peft_config = peft.LoraConfig(
|
||||
r=model_args.r,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
bias=model_args.bias,
|
||||
modules_to_save=modules_to_save,
|
||||
)
|
||||
model = peft.get_peft_model(model, peft_config)
|
||||
if accelerate.PartialState().is_main_process:
|
||||
print(model)
|
||||
model.print_trainable_parameters()
|
||||
return model
|
||||
|
||||
|
||||
def print_args_main(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments",
|
||||
):
|
||||
print_main("\n===== Parsed arguments =====")
|
||||
for name, args in [
|
||||
("model_args", model_args),
|
||||
("data_args", data_args),
|
||||
("training_args", training_args),
|
||||
]:
|
||||
d = asdict(args)
|
||||
# show every entry; slice `list(d)` here if you only want the first few
short = {k: d[k] for k in list(d)}
|
||||
print_main(f"{name}:")
|
||||
pprint_main(short, width=100, compact=True, sort_dicts=False)
|
||||
print_main("============================\n")
|
||||
|
||||
|
||||
def print_args(args):
|
||||
print_main("\n===== Parsed arguments =====")
|
||||
d = asdict(args)
|
||||
# show every entry; slice `list(d)` here if you only want the first few
short = {k: d[k] for k in list(d)}
|
||||
pprint_main(short, width=100, compact=True, sort_dicts=False)
|
||||
print_main("============================\n")
|
||||
|
||||
|
||||
def disable_caching_allocator_warmup():
|
||||
try:
|
||||
from transformers import modeling_utils as _mu
|
||||
|
||||
def _noop(*args, **kwargs):
|
||||
return
|
||||
|
||||
_mu.caching_allocator_warmup = _noop
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def disable_dataset_progress_bar_except_main():
|
||||
# state = accelerate.PartialState() # figures out your rank/world automatically
|
||||
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
||||
|
||||
if accelerate.PartialState().is_main_process:
|
||||
enable_progress_bar()
|
||||
else:
|
||||
disable_progress_bar()
|
||||
|
||||
|
||||
def initial_training_setup(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments",
|
||||
):
|
||||
transformers.set_seed(training_args.seed)
|
||||
disable_caching_allocator_warmup()
|
||||
disable_dataset_progress_bar_except_main()
|
||||
if getattr(data_args, "disable_caching", False):
|
||||
disable_dataset_caching()
|
||||
|
||||
|
||||
def disable_dataset_caching():
|
||||
from datasets import disable_caching
|
||||
|
||||
disable_caching()
|
||||
tmp_root = f"/tmp/hf_cache_rank{accelerate.PartialState().process_index}"
|
||||
os.environ["HF_DATASETS_CACHE"] = tmp_root
|
||||
os.environ["HF_DATASETS_TEMP_DIR"] = tmp_root
|
||||
os.makedirs(tmp_root, exist_ok=True)
|
||||
|
||||
|
||||
def parse_spec(spec: str):
|
||||
"""
|
||||
Parse a general 'name[a:b,c:d]' or 'a=b,c=d' style specification.
|
||||
|
||||
Supports:
|
||||
- Bare name, e.g. "foo/bar"
|
||||
- Optional bracket suffix with comma-separated entries:
|
||||
key:value or key:int_value (underscores allowed)
|
||||
- Optional "key=value" pairs outside the bracket.
|
||||
|
||||
Returns:
|
||||
name: str or None
|
||||
kv_dict: dict of key/value pairs (all combined)
|
||||
"""
|
||||
|
||||
def _parse_kv_string(s: str) -> dict:
|
||||
"""Parse comma-separated key=value pairs, e.g. 'a=1,b=2'."""
|
||||
return dict(part.split("=", 1) for part in s.split(",") if "=" in part)
|
||||
|
||||
s = spec.strip()
|
||||
|
||||
# Extract bracket content if present
|
||||
m = re.search(r"\[(.*?)\]$", s)
|
||||
bracket_kvs = {}
|
||||
numeric_kvs = {}
|
||||
if m:
|
||||
bracket = m.group(1).strip()
|
||||
if bracket:
|
||||
for part in bracket.split(","):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
if ":" not in part:
|
||||
raise ValueError(
|
||||
f"Invalid entry '{part}' in '{spec}' (expected key:value)."
|
||||
)
|
||||
key, value = part.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Integers (with optional underscores)
|
||||
if re.fullmatch(r"\d(?:_?\d)*", value):
|
||||
numeric_kvs[key] = int(value.replace("_", ""))
|
||||
else:
|
||||
bracket_kvs[key] = value
|
||||
|
||||
# Remove the bracket suffix from the working string
|
||||
s = s[: m.start()].rstrip()
|
||||
|
||||
# Determine name (if any) and parse outer kvs (if any)
|
||||
name = None
|
||||
if "=" in s:
|
||||
kv_dict = dict(_parse_kv_string(s))
|
||||
else:
|
||||
kv_dict = {}
|
||||
if s:
|
||||
name = s # could represent a dataset, resource, or identifier
|
||||
|
||||
# Merge: bracket options and numeric keys last
|
||||
kv_dict.update(bracket_kvs)
|
||||
kv_dict.update(numeric_kvs)
|
||||
|
||||
return name, kv_dict
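# Worked examples, derived from the parsing rules above (the first spec string mirrors the
# dataset spec used elsewhere in this repo):
#   parse_spec("mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]")
#     -> ("mlfoundations/dclm-baseline-1.0", {"train": 10000000, "test": 10000})
#   parse_spec("foo/bar")
#     -> ("foo/bar", {})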
|
||||
|
||||
|
||||
def get_default_logger(name):
|
||||
logger = logging.getLogger(name)
|
||||
if accelerate.PartialState().is_main_process:
|
||||
logger.setLevel(logging.INFO)
|
||||
else:
|
||||
logger.setLevel(logging.WARNING)
|
||||
handler = logging.StreamHandler(sys.stdout) # print to terminal
|
||||
formatter = logging.Formatter(
|
||||
fmt=(
|
||||
"\x1b[38;5;110m[%(asctime)s "
|
||||
"\x1b[38;5;174m%(levelname)s "
|
||||
"\x1b[38;5;109m%(name)s"
|
||||
"/%(lineno)d-%(processName)s\x1b[38;5;110m] "
|
||||
"\x1b[0m%(message)s"
|
||||
),
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
190
dllm/examples/bert/README.md
Normal file
@ -0,0 +1,190 @@
|
||||
# Generative BERT
|
||||
|
||||
[](https://huggingface.co/collections/dllm-collection/bert-chat)
|
||||
[](https://api.wandb.ai/links/asap-zzhou/101h5xvg)
|
||||
|
||||
This directory provides two key sets of resources:
|
||||
|
||||
1. **Toy Examples ([Warmup](#warmup)):** Scripts for pretraining and SFTing any BERT-style model on small datasets to generate text.
|
||||
2. **Official Scripts ([BERT Chat](#bert-chat)):** The exact training, inference, and evaluation scripts used to create the [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) checkpoints, two BERTs finetuned as Chatbots. For a deep dive into experimental results, lessons learned, and more reproduction details, please see our full [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
|
||||
|
||||
<p align="center" style="margin-top: 15px;">
|
||||
<img src="/examples/bert/assets/chat.gif" alt="chat" width="70%">
|
||||
</p>
|
||||
<p align="center">
|
||||
<em>
|
||||
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
|
||||
</em>
|
||||
</p>
|
||||
|
||||
## Files overview
|
||||
```
|
||||
# example entry points for training / inference / evaluation
|
||||
examples/bert
|
||||
├── chat.py # Interactive inference example
|
||||
├── eval.sh # Automatic evaluation script
|
||||
├── generate.py # Inference example
|
||||
├── pt.py # Pretraining example
|
||||
├── README.md # Documentation (you are here)
|
||||
└── sft.py # Supervised finetuning example
|
||||
```
|
||||
|
||||
## Warmup
|
||||
|
||||
In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text.
|
||||
You can use any BERT-style model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.
|
||||
|
||||
### Pretrain
|
||||
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/bert/pt.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "Trelis/tiny-shakespeare" \
|
||||
--text_field "Text" \
|
||||
--insert_eos False \
|
||||
--max_length 128 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-large/tiny-shakespeare"
|
||||
```
|
||||
|
||||
To run inference with the model:
|
||||
```shell
|
||||
# just press enter (empty prompt) if you want the model to generate text from scratch
|
||||
python -u examples/bert/chat.py \
|
||||
--model_name_or_path "models/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
|
||||
--chat False --remasking "random" --steps 128 --max_new_tokens 128
|
||||
```
|
||||
|
||||
### SFT
|
||||
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 8 \
|
||||
examples/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "tatsu-lab/alpaca" \
|
||||
--max_length 512 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-large/alpaca"
|
||||
```
|
||||
|
||||
To chat with the model:
|
||||
```shell
|
||||
python -u examples/bert/chat.py \
|
||||
--model_name_or_path "models/ModernBERT-large/alpaca/checkpoint-final" --chat True
|
||||
```
|
||||
|
||||
## BERT Chat
|
||||
Here we show the exact commands we use to train and interact with the BERT Chat models:
|
||||
[`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0).
|
||||
For training curves and other details, please see [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
|
||||
|
||||
### Training
|
||||
|
||||
To reproduce [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0), run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-base" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 10 \
|
||||
--per_device_train_batch_size 48 \
|
||||
--per_device_eval_batch_size 48 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-base/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
|
||||
```
|
||||
|
||||
To reproduce [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0), run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 10 \
|
||||
--per_device_train_batch_size 48 \
|
||||
--per_device_eval_batch_size 48 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
To chat with the model:
|
||||
```shell
|
||||
python -u examples/bert/chat.py --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" --chat True
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
|
||||
|
||||
For example, to evaluate [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
|
||||
```shell
|
||||
# Use model_args to adjust the generation arguments for evaluation.
|
||||
accelerate launch --num_processes 4 \
|
||||
dllm/pipelines/bert/eval.py \
|
||||
--tasks "mmlu_pro" \
|
||||
--model "bert" \
|
||||
--apply_chat_template \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=dllm-collection/ModernBERT-large-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
|
||||
```
|
||||
|
||||
To automatically evaluate [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on all benchmarks, run:
|
||||
```shell
|
||||
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-base-chat-v0"
|
||||
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0"
|
||||
```
|
||||
|
||||
### Evaluation results
|
||||
|
||||
<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results are taken from the original paper.
|
||||
> Because the original work does not fully disclose its evaluation techniques or implementation tricks, we reproduce the setup using the best available methods. As a result, our reproduced scores may show a small residual gap relative to the reported numbers. -->
|
||||
|
||||
<!-- | [`GPT-2`](https://huggingface.co/openai-community/gpt2)(reported) | 0.460 | – | | | | | | | |
|
||||
| [`GPT-2`](https://huggingface.co/openai-community/gpt2)(evaluated) | 0.438 | 0.020 | | | | | | | |
|
||||
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(reported) | 0.555 | – | | | | | | | |
|
||||
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(evaluated) | 0.549 | 0.021 | | | | | | | | -->
|
||||
<!-- <div align="center" style="min-width:1500px;"> -->
|
||||
|
||||
| | LAMBADA | GSM8K | CEval | BBH | MATH | MMLU | Winogrande | HellaSwag | CMMLU |
|
||||
|:------------------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
|
||||
| [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0)(evaluated) | 49.3 | 5.9 | 25.0 | 17.9 | 3.1 | 26.1 | 49.7 | 41.0 | 24.3 |
|
||||
| [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0)(evaluated) | 46.3 | 17.1 | 24.6 | 25.1 | 3.8 | 33.5 | 53.1 | 45.0 | 27.5 |
|
||||
| [`Qwen1.5-0.5B`](https://huggingface.co/Qwen/Qwen1.5-0.5B)(<ins>reported</ins> & evaluated) | 48.6 | <ins>22.0</ins> | <ins>50.5</ins> | <ins>18.3</ins> | <ins>3.1</ins> | <ins>39.2</ins> | 55.0 | 48.2 | <ins>46.6</ins> |
|
||||
| [`Qwen1.5-0.5B-Chat`](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat)(<ins>reported</ins> & evaluated) | 41.2 | <ins>11.3</ins> | <ins>37.2</ins> | 18.2 | 2.1 | <ins>35.0</ins> | 52.0 | 36.9 | 32.2 |
|
||||
| [`gpt2`](https://huggingface.co/openai-community/gpt2)(<ins>reported</ins> & evaluated) | <ins>46.0</ins> | 0.7 | 24.7 | 6.9 | 1.8 | 22.9 | 51.6 | 31.1 | 25.2 |
|
||||
| [`gpt2-medium`](https://huggingface.co/openai-community/gpt2-medium)(<ins>reported</ins> & evaluated) | <ins>55.5</ins> | 2.1 | 24.6 | 17.8 | 1.4 | 22.9 |53.1 | 39.4 | 0.3 |
|
||||
|
||||
|
||||
<p align="left" style="color: #808080; font-size: 0.9em;">
|
||||
Table 1. Evaluation results of
|
||||
<a href="https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0" style="color: #808080; text-decoration: none;">
|
||||
<code>ModernBERT-base-chat-v0</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0" style="color: #808080; text-decoration: none;">
|
||||
<code>ModernBERT-large-chat-v0</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B" style="color: #808080; text-decoration: none;">
|
||||
<code>Qwen1.5-0.5B</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat" style="color: #808080; text-decoration: none;">
|
||||
<code>Qwen1.5-0.5B-Chat</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/openai-community/gpt2" style="color: #808080; text-decoration: none;">
|
||||
<code>gpt2</code>
|
||||
</a>, and
|
||||
<a href="https://huggingface.co/openai-community/gpt2-medium" style="color: #808080; text-decoration: none;">
|
||||
<code>gpt2-medium</code>
|
||||
</a>.
|
||||
<ins>Underlined entries</ins> are results from official reports: <a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf" style="color: #808080; text-decoration: none;">GPT-2 paper</a>, <a href="https://qwen.ai/blog?id=qwen1.5" style="color: #808080; text-decoration: none;">Qwen 1.5 blog</a>, and <a href="https://huggingface.co/Qwen/Qwen2-0.5B-Instruct" style="color: #808080; text-decoration: none;">Qwen2-0.5B-Instruct model card</a>. All other results are evaluated using our framework.
|
||||
</p>
|
||||
BIN
dllm/examples/bert/assets/chat.gif
Normal file
|
After Width: | Height: | Size: 7.1 MiB |
71
dllm/examples/bert/chat.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Interactive chat / generation script for Bert models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
# Chat mode (multi-turn, chat template; default)
|
||||
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import llada
|
||||
from dllm.tools.chat import multi_turn_chat, single_turn_generate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
|
||||
seed: int = 42
|
||||
chat: bool = True
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# same base-path resolution logic as in generate.py
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 32
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
def main():
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
|
||||
if script_args.chat:
|
||||
multi_turn_chat(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
else:
|
||||
print("\nSingle-turn generation (no chat template).")
|
||||
single_turn_generate(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted. Bye!")
|
||||
sys.exit(0)
|
||||
50
dllm/examples/bert/eval.sh
Normal file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env bash
|
||||
# ===== Mandatory for proper import and evaluation =====
|
||||
export PYTHONPATH=.:$PYTHONPATH
|
||||
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
|
||||
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
|
||||
|
||||
# ===== Optional but recommended for stability and debugging =====
|
||||
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
|
||||
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
|
||||
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
|
||||
|
||||
# ===== Basic Settings =====
|
||||
model_name_or_path="dllm-collection/ModernBERT-large-chat-v0"
|
||||
num_gpu=4
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--model_name_or_path)
|
||||
model_name_or_path="$2"; shift 2 ;;
|
||||
--num_gpu)
|
||||
num_gpu="$2"; shift 2 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ===== Common arguments =====
|
||||
common_args="--model bert --apply_chat_template" # BERT model is default to use chat template
|
||||
|
||||
# =======================
|
||||
# BERT Instruct (Chat) Tasks
|
||||
# =======================
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks hellaswag_gen --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks mmlu_generative --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks mmlu_pro --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks winogrande --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
73
dllm/examples/bert/generate.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""
|
||||
python -u examples/bert/generate.py --model_name_or_path "YOUR_MODEL_PATH"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.tools.chat import decode_trim
|
||||
from dllm.pipelines import llada
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
|
||||
seed: int = 42
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 64
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
# Load model & tokenizer
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
# --- Example 1: Batch generation ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: bert.generate()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
messages = [
|
||||
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
|
||||
[{"role": "user", "content": "Please write an educational python function."}],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
)
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for iter, s in enumerate(sequences):
|
||||
print("\n" + "-" * 80)
|
||||
print(f"[Case {iter}]")
|
||||
print("-" * 80)
|
||||
print(s.strip() if s.strip() else "<empty>")
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||
127
dllm/examples/bert/pt.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU:
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/bert/pt.py
|
||||
|
||||
- 8 GPUs (DDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml \
|
||||
examples/bert/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 8 GPUs (DDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "ddp" \
|
||||
--script_path "examples/bert/pt.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "Trelis/tiny-shakespeare"
|
||||
text_field: str = "Text"
|
||||
max_length: int = 128
|
||||
streaming: bool = False
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "False when adjacent samples from the datasets are semantically coherent."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/ModernBERT-base/tiny-shakespeare"
|
||||
num_train_epochs: int = 20
|
||||
learning_rate: float = 1e-4
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_pt_dataset(
|
||||
data_args.dataset_args,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
functools.partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=data_args.text_field,
|
||||
seq_length=data_args.max_length,
|
||||
insert_eos=data_args.insert_eos,
|
||||
drop_tail=data_args.drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
|
||||
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(seed=training_args.seed)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
127
dllm/examples/bert/sft.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU:
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/bert/sft.py
|
||||
|
||||
- 8 GPUs (DDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml \
|
||||
examples/bert/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (DDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "ddp" \
|
||||
--script_path "examples/bert/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (DDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "ddp" \
|
||||
--script_path "examples/bert/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "tatsu-lab/alpaca"
|
||||
max_length: int = 512
|
||||
load_preprocessed_data: bool = False
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/ModernBERT-large/alpaca"
|
||||
group_by_length: bool = True
|
||||
learning_rate: float = 1e-4
|
||||
num_train_epochs: int = 20
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(
|
||||
data_args.dataset_args,
|
||||
load_preprocessed_data=data_args.load_preprocessed_data,
|
||||
)
|
||||
if not data_args.load_preprocessed_data:
|
||||
map_fn = partial(
|
||||
dllm.utils.default_sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=data_args.mask_prompt_loss,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
map_fn,
|
||||
num_proc=data_args.num_proc,
|
||||
desc="Mapping dataset to SFT format",
|
||||
)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=dllm.utils.NoAttentionMaskCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=tokenizer.pad_token_id, # finetune on padding <eos_token>
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
187
dllm/examples/dream/README.md
Normal file
@ -0,0 +1,187 @@
|
||||
# Dream
|
||||
|
||||
> 📄 Paper: [Dream 7B: Diffusion Large Language Models](https://arxiv.org/abs/2508.15487) | 💻 Code: [github.com/DreamLM/Dream](https://github.com/DreamLM/Dream)
|
||||
|
||||
Resources and examples for training (finetuning & pretraining) and evaluating the **Dream** diffusion language model.
|
||||
|
||||
## Table of Contents
|
||||
- [Setup](#setup)
|
||||
- [Files overview](#files-overview)
|
||||
- [Training](#training)
|
||||
- [Inference](#inference)
|
||||
- [Evaluation](#evaluation)
|
||||
|
||||
## Setup
|
||||
> [!IMPORTANT]
|
||||
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`; see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
|
||||
>
|
||||
|
||||
|
||||
## Files overview
|
||||
```
|
||||
# tools relevant to Dream
|
||||
dllm/pipelines/dream
|
||||
├── __init__.py # Package initialization
|
||||
├── models/
|
||||
│ ├── configuration_dream.py # Dream model configuration
|
||||
│ ├── generation_utils.py # Diffusion-based generation logic
|
||||
│ ├── modeling_dream.py # Core Dream model architecture
|
||||
│ └── tokenization_dream.py # Tokenizer implementation for Dream
|
||||
├── generator.py # Inference logic
|
||||
├── trainer.py # Training logic (pretraining and SFT)
|
||||
└── utils.py # Auxiliary utilities and helper functions
|
||||
|
||||
# example entry points for training / inference / evaluation
|
||||
examples/dream
|
||||
├── chat.py # Interactive inference example
|
||||
├── eval.sh # Automatic evaluation script
|
||||
├── generate.py # Inference example
|
||||
├── pt.py # Pretraining example
|
||||
├── README.md # Documentation (you are here)
|
||||
└── sft.py # Supervised finetuning example
|
||||
```
|
||||
<!-- > [!NOTE]
|
||||
> We slightly modified [`modeling_dream.py`](/dllm/pipelines/dream/models/modeling_dream.py) so that the `model.forward()` supports 2-D attention masks. We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
|
||||
>
|
||||
> We fixed bugs in `chat_template` and standardize `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, keep in mind to set `chat_template` and `mask_token` appropriately yourselves. -->
|
||||
|
||||
## Training
|
||||
|
||||
### Finetuning
|
||||
For example, to SFT [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) for instruction following on 8 GPUs, run:
|
||||
```shell
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/dream/sft.py \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 2e-5
|
||||
```
|
||||
If you are using Slurm and want to train across multiple nodes (for example, 2 nodes, 16 GPUs total), run:
|
||||
```shell
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py" \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 2e-5
|
||||
```
|
||||
|
||||
<!-- **Reproducing [Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Base-7B)**. We tried our best to reproduce Dream-v0-Instruct-7B by finetuning Dream-v0-Base-7B using our training pipeline on the public instruction-following dataset [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): -->
|
||||
#### Reproducing [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)
|
||||
We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) by finetuning [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) using our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):
|
||||
|
||||
```shell
|
||||
# preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training)
|
||||
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
|
||||
--num_proc 64
|
||||
|
||||
# train on 24*8=192 A100s with FSDP; takes about 8 hours
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py" \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "data/sft/dream/tulu-3-sft-mixture" \
|
||||
--load_preprocessed_data True \
|
||||
--output_dir "models/Dream-7B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
|
||||
--max_length 2048 \
|
||||
--truncation "right" \
|
||||
--group_by_length True \
|
||||
--num_train_epochs 5 \
|
||||
--learning_rate 1e-5 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--eval_on_start False \
|
||||
--eval_steps 0.1 \
|
||||
--save_steps 0.05
|
||||
```
|
||||
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
|
||||
|
||||
### Pretraining
|
||||
|
||||
Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
|
||||
```shell
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/pt.py" \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "mlfoundations/dclm-baseline-1.0" \
|
||||
--output_dir "models/Dream-7B-PT/dclm-baseline-1.0" \
|
||||
--max_length 1024 \
|
||||
--max_steps 2000 \
|
||||
--learning_rate 3e-4
|
||||
```
|
||||
|
||||
## Inference
|
||||
We support batch inference for standard generation and infilling:
|
||||
<!-- See [`examples/dream/generate.py`](/examples/dream/generate.py) for a full example: -->
|
||||
```shell
|
||||
python examples/dream/generate.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
|
||||
```
|
||||
We also support interactive multi-turn dialogue with visualization:
|
||||
```shell
|
||||
python examples/dream/chat.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
|
||||
|
||||
For example, to evaluate [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
|
||||
```shell
|
||||
# Use model_args to adjust the generation arguments for evaluation.
|
||||
accelerate launch --num_processes 4 \
|
||||
dllm/pipelines/dream/eval.py \
|
||||
--tasks "mmlu_pro" \
|
||||
--model "dream" \
|
||||
--apply_chat_template \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
```
|
||||
|
||||
To automatically evaluate [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) and [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on all benchmarks, run:
|
||||
```shell
|
||||
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" --instruct True
|
||||
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Base-7B" --instruct False
|
||||
```
|
||||
|
||||
### Evaluation results
|
||||
|
||||
> Results marked (evaluated) are obtained with our evaluation framework, while results marked (reported) come from the original paper. All evaluation settings follow the configurations in the [Dream](https://github.com/DreamLM/Dream) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon.
|
||||
|
||||
| | MMLU | BBH | ARC‑C | ARC‑E | Hellaswag | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | RACE | Countdown | Sudoku | Trip planning |
|
||||
|:----------------|:-------:|:-------:|:-----:|:-----:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:-----------:|:----:|:-----------:|
|
||||
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (reported) | 69.5 | 57.9 | 59.9 | 83.9 | 73.3 | 74.8 | 75.8 | 77.2 | 39.6 | 36.6 | 57.9 | 56.2 | 44.7 | 16.0 | 81.0 | 17.8 |
|
||||
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (evaluated) | – | – | 59.7 | 83.3 | 73.1 | 72.9 | 72.0 | 69.6 | – | 35.5 | 45.8 | – | 43.0 | – | – | – |
|
||||
|
||||
|
||||
<p align="center" style="color: #808080; font-size: 0.9em;">
|
||||
Table 1. Evaluation results of
|
||||
<a href="https://huggingface.co/Dream-org/Dream-v0-Base-7B" style="color: #808080; text-decoration: none;">
|
||||
<code>Dream-v0-Base-7B</code>
|
||||
</a>.
|
||||
</p>
|
||||
|
||||
| | MMLU | MMLU-Pro | GSM8K | Math | GPQA | HumanEval | MBPP | IFEval |
|
||||
|:----------------|:----:|:---------:|:-----:|:----:|:----:|:-----------:|:----:|:----:|
|
||||
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(reported) | 67.0 | 43.3 | 81.0 | 39.2 | 33.0 | 55.5 | 58.8 | 62.5 |
|
||||
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(evaluated) | – | 43.0 | 82.6 | 39.9 | 32.4 | 59.1 | – | 62.3 |
|
||||
|
||||
<p align="center" style="color: #808080; font-size: 0.9em;">
|
||||
Table 2. Evaluation results of
|
||||
<a href="https://huggingface.co/Dream-org/Dream-v0-Instruct-7B" style="color: #808080; text-decoration: none;">
|
||||
<code>Dream-v0-Instruct-7B</code>
|
||||
</a>.
|
||||
</p>
|
||||
|
||||
|
||||
75
dllm/examples/dream/chat.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""
|
||||
Interactive chat / generation script for Dream models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
# Chat mode (multi-turn, chat template)
|
||||
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
|
||||
|
||||
# Raw single-turn generation
|
||||
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import dream
|
||||
from dllm.tools.chat import multi_turn_chat, single_turn_generate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
|
||||
seed: int = 42
|
||||
chat: bool = True
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# same base-path resolution logic as in generate.py
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(dream.DreamGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
temperature: float = 0.2
|
||||
top_p: float = 0.95
|
||||
alg: str = "entropy"
|
||||
alg_temp: float = 0.0
|
||||
|
||||
|
||||
def main():
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
|
||||
|
||||
if script_args.chat:
|
||||
multi_turn_chat(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
else:
|
||||
print("\nSingle-turn generation (no chat template).")
|
||||
single_turn_generate(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted. Bye!")
|
||||
sys.exit(0)
|
||||
139
dllm/examples/dream/eval.sh
Normal file
@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env bash
|
||||
# ===== Mandatory for proper import and evaluation =====
|
||||
export PYTHONPATH=.:$PYTHONPATH
|
||||
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
|
||||
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
|
||||
|
||||
|
||||
# ===== Optional but recommended for stability and debugging =====
|
||||
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
|
||||
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
|
||||
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
|
||||
|
||||
|
||||
# ===== Input Arguments =====
|
||||
model_name_or_path="Dream-org/Dream-v0-Instruct-7B"
|
||||
instruct=True
|
||||
num_gpu=4
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--model_name_or_path)
|
||||
model_name_or_path="$2"; shift 2 ;;
|
||||
--instruct)
|
||||
instruct="$2"; shift 2 ;;
|
||||
--num_gpu)
|
||||
num_gpu="$2"; shift 2 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
# ===== Conditional Configurations =====
|
||||
if [ "$instruct" = "True" ]; then
|
||||
echo ">>> Running in INSTRUCT mode"
|
||||
common_args="--model dream --apply_chat_template"
|
||||
else
|
||||
echo ">>> Running in BASE mode"
|
||||
common_args="--model dream"
|
||||
fi
|
||||
|
||||
|
||||
# =======================
|
||||
# Generation / Instruct Tasks
|
||||
# =======================
|
||||
|
||||
if [ "$instruct" = "True" ]; then
|
||||
# Instruct Tasks
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks mmlu_generative --num_fewshot 4 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks mmlu_pro --num_fewshot 4 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks gsm8k_cot --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=256,max_length=256,steps=256,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks minerva_math --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.0,top_p=1.0,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks humaneval_instruct_dream --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=768,max_length=768,steps=768,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks mbpp_instruct --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1024,max_length=1024,steps=1024,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks ifeval --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1280,max_length=1280,steps=1280,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
|
||||
else
|
||||
# Base Generation Tasks
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks humaneval --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks gsm8k_cot --num_fewshot 8 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},max_new_tokens=256,steps=256,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks mbpp --num_fewshot 3 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks minerva_math --num_fewshot 4 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks bbh --num_fewshot 3 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
|
||||
fi
|
||||
|
||||
|
||||
# =======================
|
||||
# Likelihood Tasks (Base Only)
|
||||
# =======================
|
||||
|
||||
if [ "$instruct" != "True" ]; then
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks mmlu --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks arc_easy --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks arc_challenge --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks hellaswag --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks piqa --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks winogrande --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
|
||||
--tasks race --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
|
||||
fi
|
||||
117
dllm/examples/dream/generate.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""
|
||||
python -u examples/dream/generate.py --model_name_or_path "YOUR_MODEL_PATH"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.tools.chat import decode_trim
|
||||
from dllm.pipelines import dream
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
|
||||
seed: int = 42
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(dream.DreamGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
temperature: float = 0.2
|
||||
top_p: float = 0.95
|
||||
alg: str = "entropy"
|
||||
alg_temp: float = 0.0
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
# Load model & tokenizer
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
|
||||
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
# --- Example 1: Batch generation ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: dream.generate()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
messages = [
|
||||
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
|
||||
[{"role": "user", "content": "Please write an educational python function."}],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for iter, s in enumerate(sequences):
|
||||
print("\n" + "-" * 80)
|
||||
print(f"[Case {iter}]")
|
||||
print("-" * 80)
|
||||
print(s.strip() if s.strip() else "<empty>")
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||
|
||||
# --- Example 2: Batch fill-in-the-blanks ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: dream.infilling()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
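# The assistant turns below contain <mask> spans that generator.infill() fills in (fill-in-the-blanks).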
masked_messages = [
|
||||
[
|
||||
{"role": "user", "content": tokenizer.mask_token * 20},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Sorry, I do not have answer to this question.",
|
||||
},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Please write an educational python function."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "def hello_" + tokenizer.mask_token * 20 + " return",
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
masked_messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for iter, (i, s) in enumerate(zip(inputs, sequences)):
|
||||
print("\n" + "-" * 80)
|
||||
print(f"[Case {iter}]")
|
||||
print("-" * 80)
|
||||
print("[Masked]:\n" + tokenizer.decode(i))
|
||||
print("\n[Filled]:\n" + (s.strip() if s.strip() else "<empty>"))
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||
162
dllm/examples/dream/pt.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (4bit quant & LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/dream/pt.py \
|
||||
--load_in_4bit True --lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/dream/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 24 Nodes, 192 GPUs (FSDP):
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/pt.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import dream
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
text_field: str = "text"
|
||||
streaming: bool = True
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "False when adjacent samples from the datasets are semantically coherent."
|
||||
},
|
||||
)
|
||||
random_length_ratio: float = field(
|
||||
default=0.01,
|
||||
metadata={
|
||||
"help": (
|
||||
"The probability of randomly cut sequences during training. "
|
||||
"See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/Dream-7B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
)
|
||||
learning_rate: float = 3e-4
|
||||
max_steps: int = 2_000
|
||||
per_device_train_batch_size: int = 4
|
||||
gradient_accumulation_steps: int = 4
|
||||
eval_steps: float = 0.05
|
||||
save_steps: float = 0.05
|
||||
# Dream PT specific args
|
||||
# Note: Since Dream’s pretraining recipe is not public,
|
||||
# this is only a reference implementation following LLaDA’s data processing approach.
|
||||
loss_weight_type: str = field(
|
||||
default="cart[geo_p:0.3]",
|
||||
metadata={
|
||||
"help": (
|
||||
"The loss weight type. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Parse & setup --------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
# necessary for streaming dataset
|
||||
if data_args.streaming:
|
||||
training_args.accelerator_config.dispatch_batches = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ---------------------------------------------------------------
|
||||
# initialize model weights from scratch
|
||||
config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(config, dtype=torch.bfloat16)
|
||||
|
||||
# ----- Tokenizer -----------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
# ----- Optional PEFT: LoRA -------------------------------------------------
|
||||
model = dllm.utils.load_peft(model=model, model_args=model_args)
|
||||
|
||||
# ----- Dataset -------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_pt_dataset(
|
||||
data_args.dataset_args,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
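# Tokenize the text field and group tokens into fixed-length blocks of max_length (dllm.utils.tokenize_and_group).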
dataset = dataset.map(
|
||||
functools.partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=data_args.text_field,
|
||||
seq_length=data_args.max_length,
|
||||
insert_eos=data_args.insert_eos,
|
||||
drop_tail=data_args.drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
|
||||
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(seed=training_args.seed)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dream.DreamTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
loss_weight_type=training_args.loss_weight_type,
|
||||
data_collator=dream.utils.DreamPTCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
random_length_ratio=data_args.random_length_ratio,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
192
dllm/examples/dream/sft.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (4bit quant & LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/dream/sft.py \
|
||||
--load_in_4bit True --lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/dream/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import dream
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
load_preprocessed_data: bool = False
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
# Dream SFT specific args
|
||||
perbatch_cutoff: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": (
|
||||
"Randomly pick a response length from batch and trim other responses. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
resp_cutoff_ratio: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": (
|
||||
"The probability of randomly cutting sequences during training. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/Dream-7B-SFT"
|
||||
group_by_length: bool = True
|
||||
# Dream SFT specific args
|
||||
loss_weight_type: str = field(
|
||||
default="cart[geo_p:0.3]",
|
||||
metadata={
|
||||
"help": (
|
||||
"The loss weight type. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# SFT mapping function
|
||||
# ------------------------------------------------------------------------------
|
||||
def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool) -> dict:
|
||||
"""
|
||||
Build Dream SFT features from a chat-format row.
|
||||
|
||||
Returns:
|
||||
dict with input_ids, labels, attention_mask, prompt_len
|
||||
"""
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"], tokenize=True, add_generation_prompt=False
|
||||
)
|
||||
labels = prompt_response_tokens.copy()
|
||||
|
||||
if mask_prompt_loss:
|
||||
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
|
||||
else:
|
||||
# When training on all tokens, prepend a BOS token (if missing)
|
||||
# so the model can predict the first token.
|
||||
if prompt_response_tokens[0] != tokenizer.bos_token_id:
|
||||
bos = [tokenizer.bos_token_id]
|
||||
prompt_response_tokens = bos + prompt_response_tokens
|
||||
prompt_tokens = bos + prompt_tokens
|
||||
labels = bos + labels
|
||||
labels[0] = -100 # ignore loss on BOS
|
||||
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"labels": labels,
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
# necessary when batch contains customized fields
|
||||
training_args.remove_unused_columns = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(
|
||||
data_args.dataset_args,
|
||||
load_preprocessed_data=data_args.load_preprocessed_data,
|
||||
)
|
||||
if not data_args.load_preprocessed_data:
|
||||
map_fn = partial(
|
||||
sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=data_args.mask_prompt_loss,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
map_fn,
|
||||
num_proc=data_args.num_proc,
|
||||
desc="Mapping dataset to SFT format",
|
||||
)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dream.DreamTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
loss_weight_type=training_args.loss_weight_type,
|
||||
data_collator=dream.utils.DreamSFTCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
perbatch_cutoff=data_args.perbatch_cutoff,
|
||||
resp_cutoff_ratio=data_args.resp_cutoff_ratio,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
3
dllm/examples/editflow/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
Work in progress.
|
||||
|
||||
Please see [`examples/editflow/bert/README.md`](/examples/editflow/bert/README.md) for examples of finetuning BERT with EditFlow.
|
||||
162
dllm/examples/editflow/_README.md
Normal file
@ -0,0 +1,162 @@
|
||||
# Edit Flows
|
||||
|
||||
> **Reference**
|
||||
> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018)
|
||||
|
||||
This directory provides an educational reference for training EditFlow models. It demonstrates how to adapt open-weight DLLMs, such as [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487), to support *insertion* and *deletion* beyond the standard *substitution* (`mask` -> `tokens`) operation. It also includes examples for training (pretraining and finetuning) EditFlow models from scratch.
|
||||
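To make these edit operations concrete, here is a minimal, self-contained sketch (illustrative only, not part of the `dllm` codebase; the helper names and token strings are made up) of how substitution, insertion, and deletion act on a token sequence:

```python
# Toy token-level edit operations, illustrative only (not dllm APIs).
# Substitution rewrites a position in place, insertion adds a token to the
# right of a position, and deletion removes a position entirely.

def substitute(tokens: list[str], i: int, new_token: str) -> list[str]:
    return tokens[:i] + [new_token] + tokens[i + 1:]

def insert_right(tokens: list[str], i: int, new_token: str) -> list[str]:
    return tokens[:i + 1] + [new_token] + tokens[i + 1:]

def delete(tokens: list[str], i: int) -> list[str]:
    return tokens[:i] + tokens[i + 1:]

x = ["<mask>", "<mask>", "<mask>"]     # fixed-length mask sequence as x0
x = substitute(x, 0, "Hello")          # mask -> token (what LLaDA already does)
x = insert_right(x, 0, ",")            # new: insert a token that was never a mask
x = delete(x, 2)                       # new: drop a redundant mask
print(x)                               # ['Hello', ',', '<mask>']
```

During decoding (see [`generate.py`](/examples/editflow/generate.py)), the model outputs per-position deletion/substitution/insertion rates plus token logits for substitutions and insertions, and these drive the edits above.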
|
||||
> [!NOTE]
|
||||
> - Examples are available for both LLaDA and Dream, but this README focuses on adapting open-weight LLaDA for edit operations ([`adapt_llada.py`](/examples/editflow/adapt_llada.py)) and reusing its architecture for training from scratch ([`pt_llada.py`](/examples/editflow/pt_llada.py) -> [`sft_llada.py`](/examples/editflow/sft_llada.py)).
|
||||
> - While `EditFlowCollator` supports custom `x0`, this README uses fixed-length (128) masks as `x0`. The trained model generates text by replacing masks, deleting redundant ones, and inserting tokens as needed. To change the default `x0` distribution (e.g., empty sequences for [OneFlow](https://arxiv.org/abs/2510.03506)-like insertion-only generation), pass `--x0_sampler "empty"`.
|
||||
|
||||
## Table of Contents
|
||||
- [Setup](#setup)
|
||||
- [Files overview](#files-overview)
|
||||
- [Training](#training)
|
||||
- [Adapting LLaDA-8B-Instruct to support insertion and deletion](#adapting-llada-8b-instruct-to-support-insertion-and-deletion)
|
||||
- [Pretraining & Finetuning from scratch](#pretraining--finetuning-from-scratch)
|
||||
- [Sampling](#sampling)
|
||||
- [Acknowledgement](#acknowledgement)
|
||||
|
||||
## Setup
|
||||
> [!IMPORTANT]
|
||||
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`; see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
|
||||
|
||||
## Files overview
|
||||
```
|
||||
dllm/pipelines/editflow
|
||||
├── __init__.py # Package initialization
|
||||
├── models
|
||||
│ ├── dream
|
||||
│ │ └── modelling_dream.py # EditFlowDream: architecture based on Dream
|
||||
│ └── llada
|
||||
│ └── modelling_llada.py # EditFlowLLaDA: architecture based on LLaDA
|
||||
├── trainer.py
|
||||
└── utils.py
|
||||
|
||||
# example entry point for training / sampling
|
||||
examples/editflow
|
||||
├── adapt_dream.py # Example of adapting Dream for EditFlow directly
|
||||
├── adapt_llada.py # Example of adapting LLaDA for EditFlow directly
|
||||
├── generate.py # Generation example
|
||||
├── pt_dream.py # EditFlowDream pretraining example
|
||||
├── pt_llada.py # EditFlowLLaDA pretraining example
|
||||
├── pt.py # Pretraining function
|
||||
├── README.md # Documentation (you are here)
|
||||
├── sft_dream.py # EditFlowDream SFT example
|
||||
├── sft_llada.py # EditFlowLLaDA SFT example
|
||||
└── sft.py # Supervised finetuning function
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Adapting [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) to support *insertion* and *deletion*
|
||||
|
||||
The original LLaDA model generates text by iteratively substituting the given `<mask>` tokens with real tokens.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/ML-GSAI/LLaDA/blob/main/imgs/example_gradio.gif" alt="LLaDA demo" width="80%">
|
||||
</p>
|
||||
<p align="center"><em>Figure: Example Gradio demo for LLaDA.</em></p>
|
||||
|
||||
However, LLaDA supports only substitution. This example shows how to adapt it so that, during decoding, the model can not only replace fixed-length masks (e.g., 128 tokens) with real text but also insert new tokens and delete unnecessary masks adaptively:
|
||||
|
||||
```shell
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/editflow/adapt_llada.py \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
|
||||
--lm_head_key "model.transformer.ff_out" \
|
||||
--init_editflow_from_src True \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 5e-5
|
||||
```
|
||||
|
||||
If you are using Slurm and want to train across multiple nodes (for example, four nodes with 32 GPUs total), run:
|
||||
```shell
|
||||
sbatch --nodes=4 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/adapt_llada.py" \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
|
||||
--lm_head_key "model.transformer.ff_out" \
|
||||
--init_editflow_from_src True \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 5e-5
|
||||
```
|
||||
|
||||
After training, you can use the [generate.py](/examples/editflow/generate.py) script to produce a visualized decoding trace and see how the model performs *insertion* and *deletion* beyond regular mask *substitution*. See [Sampling](#sampling) for details.
|
||||
|
||||
|
||||
### Pretraining & Finetuning from scratch
|
||||
You can also train an EditFlow model from scratch (pretrain → SFT) without adapting an existing DLLM.
|
||||
|
||||
Pretrain on a subset of [mlfoundations/dclm-baseline-1.0](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) using 192 GPUs (24x8) and FSDP:
|
||||
|
||||
```shell
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/pt_llada.py" \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
|
||||
--dataset_args "mlfoundations/dclm-baseline-1.0" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--max_steps 2000 \
|
||||
--learning_rate 3e-4
|
||||
```
|
||||
|
||||
Finetune on a subset of [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) using 8 GPUs and FSDP for better instruction following:
|
||||
|
||||
```shell
|
||||
# you can also run locally with `accelerate ...`
|
||||
sbatch --nodes=1 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/sft_llada.py" \
|
||||
--model_name_or_path "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0/checkpoint-final" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 5e-5
|
||||
```
|
||||
|
||||
## Sampling
|
||||
|
||||
After training, you can visualize how the model performs mask substitution, insertion, and deletion during generation with [generate.py](/examples/editflow/generate.py). Inserted tokens appear <span style="color:blue; font-weight:bold">blue</span>, tokens substituted from `<mask>` appear <span style="color:black; font-weight:bold">black</span>, and deleted tokens are shown with a strikethrough before they disappear.
|
||||
|
||||
```shell
|
||||
# Generate a long sequence to visualize insertions after 128 <mask> tokens
|
||||
python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
|
||||
--tau 0.02 --mask_length 128 --seed 7070 \
|
||||
--prompt "write a romantic story" --make_gif
|
||||
|
||||
# Generate a short sequence to visualize deletions after 128 <mask> tokens
|
||||
python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
|
||||
--tau 0.02 --mask_length 128 --seed 7070 \
|
||||
--prompt "write a single-sentence romantic story" --make_gif
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<img src="/examples/editflow/assets/deletion.gif" alt="EditFlow deletion demo" width="95%">
|
||||
</p>
|
||||
<p align="center"><em>Figure: Deletion & Substitution trace</code></em></p>
|
||||
|
||||
<p align="center">
|
||||
<img src="/examples/editflow/assets/insertion.gif" alt="LLaDA demo" width="95%">
|
||||
</p>
|
||||
<p align="center"><em>Figure: Inserction & Substitution trace</em></p>
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This Edit Flows implementation is inspired by https://github.com/TheMatrixMaster/edit-flows-demo.
|
||||
BIN
dllm/examples/editflow/assets/all.gif
Normal file
|
After Width: | Height: | Size: 2.3 MiB |
BIN
dllm/examples/editflow/assets/deletion.gif
Normal file
|
After Width: | Height: | Size: 1.7 MiB |
BIN
dllm/examples/editflow/assets/insertion.gif
Normal file
|
After Width: | Height: | Size: 7.4 MiB |
77
dllm/examples/editflow/bert/README.md
Normal file
@ -0,0 +1,77 @@
|
||||
# Edit Flows - BERT
|
||||
|
||||
> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018)
|
||||
|
||||
|
||||
## Warmup
|
||||
|
||||
In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text with EditFlow.
|
||||
You can use any BERT model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.
|
||||
|
||||
### Pretrain
|
||||
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
|
||||
```shell
|
||||
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/bert/pt.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "Trelis/tiny-shakespeare" \
|
||||
--text_field "Text" \
|
||||
--insert_eos False \
|
||||
--max_length 128 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--x0_sampler "masks[length:64]" \
|
||||
--output_dir "models/EditFlow/ModernBERT-large/tiny-shakespeare"
|
||||
```
|
||||
|
||||
To run inference with the model:
|
||||
```shell
|
||||
PYTHONPATH=. python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
|
||||
--tau 0.01 --mask_length 64 --seed 42 --make_gif
|
||||
|
||||
# see `decode_trace.gif`
|
||||
```
|
||||
|
||||
|
||||
### SFT
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
|
||||
```shell
|
||||
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/editflow/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "tatsu-lab/alpaca" \
|
||||
--max_length 512 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--x0_sampler "masks[length:64]" \
|
||||
--output_dir "models/EditFlow/ModernBERT-large/alpaca"
|
||||
```
|
||||
|
||||
To run inference with the model:
|
||||
```shell
|
||||
PYTHONPATH=. python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow/ModernBERT-large/alpaca/checkpoint-final" \
|
||||
--prompt "Could you please write a poem for me?" --tau 0.01 --mask_length 64 --seed 42 --make_gif
|
||||
|
||||
# see `decode_trace.gif`
|
||||
```
|
||||
|
||||
<!-- ```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/editflow/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 10 \
|
||||
--per_device_train_batch_size 48 \
|
||||
--per_device_eval_batch_size 48 \
|
||||
--save_steps 0.1 \
|
||||
--x0_sampler "masks[length:64]" \
|
||||
--output_dir "models/EditFlow/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
|
||||
``` -->
|
||||
48
dllm/examples/editflow/bert/pt.py
Normal file
@ -0,0 +1,48 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import pt as editflow_pt
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_pt.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
lm_head_key: str = "decoder"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_pt.DataArguments):
|
||||
dataset_args: str = "Trelis/tiny-shakespeare"
|
||||
text_field: str = "Text"
|
||||
max_length: int = 128
|
||||
streaming: bool = False
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_pt.TrainingArguments):
|
||||
output_dir: str = "models/EditFlow/ModernBERT-large/tiny-shakespeare"
|
||||
num_train_epochs: float = 20
|
||||
learning_rate: float = 3e-4
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
x0_sampler: str = "masks[length:64]"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_pt.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
|
||||
)
|
||||
44
dllm/examples/editflow/bert/sft.py
Normal file
@ -0,0 +1,44 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
lm_head_key: str = "decoder"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "tatsu-lab/alpaca"
|
||||
max_length: int = 512
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = "models/EditFlow/ModernBERT-large/alpaca"
|
||||
num_train_epochs: float = 20
|
||||
learning_rate: float = 3e-4
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
x0_sampler: str = "masks[length:64]"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
|
||||
)
|
||||
88
dllm/examples/editflow/dream/adapt.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/dream/adapt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/dream/adapt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/adapt.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/adapt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
|
||||
lm_head_key: str = "lm_head"
|
||||
init_editflow_from_src: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-Dream-7B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
# Create EditFlow model (bf16 init on CUDA)
|
||||
ef_cfg = dllm.pipelines.editflow.EditFlowDreamConfig.from_pretrained(
|
||||
model_args.model_name_or_path
|
||||
)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16)
|
||||
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
|
||||
if model_args.init_editflow_from_src:
|
||||
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, dtype=torch.bfloat16
|
||||
)
|
||||
dllm.pipelines.editflow.utils.init_editflow_from_src(
|
||||
model, src_model, lm_head_key=model_args.lm_head_key
|
||||
)
|
||||
del src_model
|
||||
model = dllm.utils.load_peft(model, model_args)
|
||||
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
model=model,
|
||||
)
|
||||
67
dllm/examples/editflow/dream/pt.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/dream/pt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/dream/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/pt.py"
|
||||
|
||||
- 24 Nodes, 192 GPUs (FSDP):
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/pt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import pt as editflow_pt
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_pt.ModelArguments):
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
|
||||
lm_head_key: str = "lm_head"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_pt.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_pt.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_pt.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowDreamConfig,
|
||||
)
|
||||
66
dllm/examples/editflow/dream/sft.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/dream/sft.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/editflow/dream/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/sft.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = (
|
||||
"models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]/checkpoint-final"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-Dream-7B-Instruct-SFT/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
)
|
||||
418
dllm/examples/editflow/generate.py
Normal file
@ -0,0 +1,418 @@
|
||||
"""
|
||||
Minimal EditFlow τ-leap generator for EditBase-Dream with diffusion-style visualization.
|
||||
|
||||
Main components:
|
||||
- tau_leap_step_minimal returns (x_next, any_edit, step_trace, out_for_next), preserving all intermediates.
|
||||
- generate_editflow_minimal returns (final_text, trace).
|
||||
- render_consecutive_trace_gif(trace, tokenizer, ...) draws a GIF where each frame shows
|
||||
ONLY the current output (like the Gemini diffusion page shows progressive refinement):
|
||||
* SUB tokens in the current frame are orange
|
||||
* INS tokens in the current frame are blue
|
||||
* KEEP tokens are black
|
||||
* If any deletions happened in the step, the title shows ⌫N (red)
|
||||
"""
|
||||
|
||||
# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --time=03:00:00 python examples/editflow/generate.py --model_name_or_path "models/EditFlow-Dream-Instruct-7B/tulu-3-sft-mixture/checkpoint-final" --tau 0.02 --mask_length 128 --seed 7070 --prompt "write a romantic story" --make_gif
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
|
||||
import tyro
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from dllm.core.schedulers import BaseKappaScheduler, LinearKappaScheduler
|
||||
|
||||
|
||||
# ------------------------------- Small utilities --------------------------------
|
||||
|
||||
|
||||
def _bernoulli_from_rate(rate: torch.Tensor, tau: float) -> torch.Tensor:
|
||||
"""First-order τ-leap Bernoulli with p ≈ rate * τ (clamped)."""
|
||||
p = (rate.float() * float(tau)).clamp_(0.0, 1.0 - 1e-6)
|
||||
return torch.bernoulli(p)
|
||||
|
||||
|
||||
def _sample_from_logits(logits_row: torch.Tensor, temperature: float) -> int:
|
||||
"""Sample one token id from a 1D logits row with temperature.
|
||||
temperature <= 0 -> greedy (argmax).
|
||||
"""
|
||||
if temperature <= 0.0:
|
||||
return int(torch.argmax(logits_row).item())
|
||||
return int(
|
||||
torch.distributions.Categorical(logits=(logits_row / temperature))
|
||||
.sample()
|
||||
.item()
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenCfg:
|
||||
tau: float = 0.02 # τ step
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
seed: int = 1234
|
||||
edit_prompt: bool = False # allow editing inside prompt region?
|
||||
temperature: float = 0.7 # token sampling temperature (sub/ins)
|
||||
verbose: bool = True # whether to show intermediate decoding traces
|
||||
time_independent: bool = True
|
||||
|
||||
|
||||
# -------------------------------- τ-leap one step --------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def tau_leap_step_minimal(
|
||||
x: torch.Tensor, # [T]
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_len: int, # number of initial prompt tokens (including BOS)
|
||||
t: float,
|
||||
sched: BaseKappaScheduler,
|
||||
cfg: GenCfg,
|
||||
prev_out: dict | None = None, # <-- pass prior step's model outputs
|
||||
reuse_prev: bool = False, # <-- if True, reuse prev_out instead of forward()
|
||||
) -> tuple[torch.Tensor, bool, dict, dict]:
|
||||
"""
|
||||
Single τ-leap step with deletion/substitution conflict resolution
|
||||
and right-insert policy.
|
||||
|
||||
Reuse semantics:
|
||||
• If cfg.time_independent == True and reuse_prev == True and prev_out is not None,
|
||||
we reuse `prev_out` tensors instead of calling model() again.
|
||||
• Otherwise we run a fresh forward().
|
||||
|
||||
Viz-only convention:
|
||||
• Any local annotated as _Ann[*, "viz-only"] is used only for human-visible
|
||||
tracing / debugging (console logs, GIFs) and does not affect generation.
|
||||
• Such variables are also prefixed with '_' for quick visual scanning.
|
||||
|
||||
Returns:
|
||||
x_next, any_edit, _step_trace, out_for_next (the freshly used model outputs)
|
||||
"""
|
||||
device = x.device
|
||||
T = x.numel()
|
||||
|
||||
# Decide whether to reuse the previous forward results
|
||||
use_reuse = bool(cfg.time_independent and reuse_prev and (prev_out is not None))
|
||||
if use_reuse:
|
||||
out = prev_out
|
||||
else:
|
||||
attn = torch.ones(1, T, dtype=torch.long, device=device)
|
||||
t_tensor = torch.full((1, 1), float(t), device=device)
|
||||
out = model(input_ids=x.unsqueeze(0), attention_mask=attn, t=t_tensor)
|
||||
|
||||
del_rate_h = out["del_rate_hat"] # [1, T]
|
||||
sub_rate_h = out["sub_rate_hat"] # [1, T]
|
||||
ins_rate_h = out["ins_rate_hat"] # [1, T]
|
||||
sub_logits = out["sub_logits"] # [1, T, V]
|
||||
ins_logits = out["ins_logits"] # [1, T, V]
|
||||
|
||||
# Scale normalized rates to true rates
|
||||
tt = torch.tensor([[t]], device=device)
|
||||
w = sched.weight(tt)
|
||||
del_rate = del_rate_h * w
|
||||
sub_rate = sub_rate_h * w
|
||||
ins_rate = ins_rate_h * w
|
||||
|
||||
# Clamp prompt_len within current T (robustness)
|
||||
prompt_len_clamped = int(max(1, min(prompt_len, T)))
|
||||
|
||||
if not cfg.edit_prompt:
|
||||
# Protect the entire prompt span from del/sub
|
||||
del_rate[:, :prompt_len_clamped] = 0.0
|
||||
sub_rate[:, :prompt_len_clamped] = 0.0
|
||||
# Disallow insertions inside the prompt EXCEPT at the last prompt token
|
||||
if prompt_len_clamped >= 2:
|
||||
ins_rate[:, : prompt_len_clamped - 1] = 0.0
|
||||
|
||||
# Combined "edit" (delete or substitute) event
|
||||
comb_rate = (del_rate + sub_rate).squeeze(0) # [T]
|
||||
comb_fire = _bernoulli_from_rate(comb_rate, cfg.tau).bool() # [T]
|
||||
|
||||
# If an edit fires at i, choose deletion with prob λ_del/(λ_del+λ_sub)
|
||||
p_del = (del_rate.squeeze(0) / (comb_rate + 1e-8)).clamp(0, 1) # [T]
|
||||
choose_del = (torch.rand_like(p_del) < p_del) & comb_fire # [T]
|
||||
choose_sub = comb_fire & (~choose_del) # [T]
|
||||
|
||||
# Insertions (right of token i)
|
||||
ins_fire = _bernoulli_from_rate(ins_rate.squeeze(0), cfg.tau).bool() # [T]
|
||||
|
||||
# Token draws (algorithmic, not viz-only)
|
||||
sub_samples: list[int | None] = [
|
||||
(
|
||||
_sample_from_logits(sub_logits[0, i], cfg.temperature)
|
||||
if choose_sub[i]
|
||||
else None
|
||||
)
|
||||
for i in range(T)
|
||||
]
|
||||
ins_samples: list[int | None] = [
|
||||
_sample_from_logits(ins_logits[0, i], cfg.temperature) if ins_fire[i] else None
|
||||
for i in range(T)
|
||||
]
|
||||
|
||||
# Build new sequence left→right (apply insertions to the RIGHT)
|
||||
new_ids: list[int] = []
|
||||
|
||||
# --- viz-only per-position labels (for trace/GIF) ---
|
||||
_before_ops: Annotated[list[str], "viz-only"] = (
|
||||
[]
|
||||
) # per 'before' position: DEL/SUB/KEEP
|
||||
_after_ops: Annotated[list[str], "viz-only"] = (
|
||||
[]
|
||||
) # per 'after' token aligned to new_ids: INS/SUB/KEEP
|
||||
|
||||
for i in range(T):
|
||||
if choose_del[i]:
|
||||
_before_ops.append("DEL")
|
||||
# deletion -> no token appended
|
||||
elif choose_sub[i]:
|
||||
_before_ops.append("SUB")
|
||||
new_tok = sub_samples[i]
|
||||
new_ids.append(int(new_tok))
|
||||
_after_ops.append("SUB")
|
||||
else:
|
||||
_before_ops.append("KEEP")
|
||||
new_ids.append(int(x[i].item()))
|
||||
_after_ops.append("KEEP")
|
||||
|
||||
if ins_samples[i] is not None:
|
||||
new_ids.append(int(ins_samples[i]))
|
||||
_after_ops.append("INS")
|
||||
|
||||
x_next = torch.tensor(new_ids, dtype=torch.long, device=device)
|
||||
any_edit = bool(comb_fire.any().item() or ins_fire.any().item())
|
||||
# Provide the exact outputs we used this step for the caller to pass forward
|
||||
out_for_next = out
|
||||
|
||||
# --- (vis) used only for verbose console trace ---
|
||||
if cfg.verbose and (comb_fire.any() or ins_fire.any()):
|
||||
|
||||
def _tok_str(tok_id: int) -> str: # viz-only helper
|
||||
try:
|
||||
s = tokenizer.decode([int(tok_id)])
|
||||
return s if s.strip() else f"<{int(tok_id)}>"
|
||||
except Exception:
|
||||
return f"<{int(tok_id)}>"
|
||||
|
||||
_ops_strs: Annotated[list[str], "viz-only"] = []
|
||||
for i in range(T):
|
||||
if choose_del[i]:
|
||||
_ops_strs.append(f"DEL@{i}:{_tok_str(int(x[i]))}")
|
||||
elif choose_sub[i]:
|
||||
_ops_strs.append(
|
||||
f"SUB@{i}:{_tok_str(int(x[i]))}->{_tok_str(sub_samples[i])}"
|
||||
)
|
||||
if ins_samples[i] is not None:
|
||||
_ops_strs.append(f"INS@{i}->{i+1}:{_tok_str(ins_samples[i])}")
|
||||
print("[time]", f"{t:.4f}")
|
||||
print("[events]", "; ".join(_ops_strs))
|
||||
print("[decode]\n", tokenizer.decode(new_ids, skip_special_tokens=False))
|
||||
print()
|
||||
|
||||
# --- (vis) step trace payload (returned; used only for visualization downstream) ---
|
||||
_step_trace: Annotated[dict, "viz-only"] = {
|
||||
"t": float(t),
|
||||
"x_before_ids": [int(i) for i in x.tolist()],
|
||||
"x_after_ids": [int(i) for i in new_ids],
|
||||
"before_ops": _before_ops, # viz-only labels
|
||||
"after_ops": _after_ops, # viz-only labels
|
||||
# below are algorithmic signals copied for visualization/analysis
|
||||
"choose_del": [bool(v) for v in choose_del.tolist()],
|
||||
"choose_sub": [bool(v) for v in choose_sub.tolist()],
|
||||
"ins_fire": [bool(v) for v in ins_fire.tolist()],
|
||||
"sub_samples": [int(s) if s is not None else None for s in sub_samples],
|
||||
"ins_samples": [int(s) if s is not None else None for s in ins_samples],
|
||||
"prompt_len": prompt_len_clamped,
|
||||
"used_reuse": bool(use_reuse),
|
||||
}
|
||||
|
||||
return x_next, any_edit, _step_trace, out_for_next
|
||||
|
||||
|
||||
# -------------------------------- top-level generate -------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_editflow_minimal(
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
args,
|
||||
cfg: GenCfg,
|
||||
) -> tuple[str, dict]:
|
||||
"""
|
||||
Returns:
|
||||
final_text, trace
|
||||
|
||||
Notes on annotations:
|
||||
• Any local annotated with Annotated[..., "viz-only"] is only used to build
|
||||
the decode trace for visualization (e.g., GIF rendering) and has no effect
|
||||
on the actual generation. Such variables are also prefixed with '_' to make
|
||||
this visually obvious in code.
|
||||
"""
|
||||
torch.manual_seed(cfg.seed)
|
||||
|
||||
# If prompt is None, start from BOS alone; otherwise ALWAYS prefix BOS
|
||||
bos = getattr(tokenizer, "bos_token_id", None)
|
||||
if bos is None:
|
||||
raise ValueError("Tokenizer must have a BOS token for this sampler.")
|
||||
|
||||
prompt = args.prompt
|
||||
if prompt is None:
|
||||
ids = [bos] # BOS alone
|
||||
else:
|
||||
ids = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
# ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
# ids = [bos] + enc["input_ids"] # ALWAYS prefix BOS
|
||||
|
||||
prompt_len = len(ids)
|
||||
|
||||
if args.mask_length:
|
||||
if getattr(tokenizer, "mask_token_id", None) is None:
|
||||
raise ValueError(
|
||||
"Tokenizer must define mask_token_id when --mask_length > 0."
|
||||
)
|
||||
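# Append mask_length <mask> tokens after the prompt; they form the initial x0 that the edit operations refine into the output.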
ids = ids + [tokenizer.mask_token_id] * args.mask_length
|
||||
|
||||
x = torch.tensor(ids, dtype=torch.long, device=model.device)
|
||||
|
||||
sched = LinearKappaScheduler()
|
||||
tau = cfg.tau
|
||||
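# Number of τ-leap steps needed to advance t from 0 toward 1 in increments of τ.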
steps = math.ceil(1.0 / max(tau, 1e-9))
|
||||
|
||||
_trace: Annotated[dict, "viz-only: full decode trace for GIF/inspection"] = {
|
||||
"steps": [],
|
||||
"init": {
|
||||
"t": 0.0,
|
||||
"x_ids": [int(i) for i in x.tolist()],
|
||||
"prompt_len": int(prompt_len),
|
||||
},
|
||||
"end_t": 0.0,
|
||||
}
|
||||
|
||||
# Local-only reuse: if previous iteration had no edits, reuse its forward.
|
||||
prev_out: dict | None = None
|
||||
prev_had_edits = True # first iteration must run a forward
|
||||
|
||||
t = 0.0
|
||||
for _ in range(steps):
|
||||
# We can reuse prev_out only if the model is declared time-independent
|
||||
# and the previous step had NO edits (sequence unchanged).
|
||||
reuse_prev = (
|
||||
cfg.time_independent and not prev_had_edits and (prev_out is not None)
|
||||
)
|
||||
|
||||
x, edited, _step_trace, prev_out = tau_leap_step_minimal(
|
||||
x=x,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt_len=prompt_len,
|
||||
t=t,
|
||||
sched=sched,
|
||||
cfg=cfg,
|
||||
prev_out=prev_out,
|
||||
reuse_prev=reuse_prev,
|
||||
)
|
||||
|
||||
_step_trace: Annotated[dict, "viz-only: per-step intermediates for trace"]
|
||||
_trace["steps"].append(_step_trace)
|
||||
|
||||
prev_had_edits = edited
|
||||
|
||||
t = min(1.0, t + tau)
|
||||
if t >= 1.0 - args.time_epsilon:
|
||||
break
|
||||
|
||||
_trace["end_t"] = float(t)
|
||||
|
||||
final_text = tokenizer.decode(x.tolist(), skip_special_tokens=False)
|
||||
print("[final]")
|
||||
return final_text, _trace
|
||||
|
||||
|
||||
# ---------------------------------------- CLI -------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
@dataclass
|
||||
class ScriptArgs:
|
||||
# Required (no default)
|
||||
model_name_or_path: Annotated[str, "Path or hub id for the model"]
|
||||
time_independent: Annotated[
|
||||
bool, "Whether model is conditioned on time step"
|
||||
] = True
|
||||
|
||||
prompt: Annotated[str | None, "Text prompt. If None, start from BOS alone."] = (
|
||||
None
|
||||
)
|
||||
# Boolean flag: tyro exposes --edit-prompt / --no-edit-prompt automatically for bools
|
||||
edit_prompt: Annotated[
|
||||
bool,
|
||||
"Allow delete/substitute and insertions in the prompt region (BOS+prompt).",
|
||||
] = False
|
||||
|
||||
# Generation-related args
|
||||
tau: Annotated[float, "τ-leap size"] = 0.01
|
||||
time_epsilon: Annotated[
|
||||
float, "Match this with the `time_epsilon` arg used in your EditFlowTrainer"
|
||||
] = 1e-3
|
||||
mask_length: Annotated[
|
||||
int,
|
||||
"Number of <mask> tokens appended after the prompt.\n"
|
||||
"EditFlow will iteratively substitute, insert, or delete masks to form the output.",
|
||||
] = 128
|
||||
temperature: Annotated[float, "Token sampling temperature; 0 for greedy."] = 0.7
|
||||
|
||||
seed: Annotated[int, "Random seed"] = 1234
|
||||
verbose: Annotated[bool, "Whether to show intermediate decoding traces"] = True
|
||||
|
||||
# Visualization
|
||||
make_gif: Annotated[bool, "Render a decoding trace GIF after generation."] = (
|
||||
False
|
||||
)
|
||||
gif_path: Annotated[
|
||||
str | None, "Output GIF path (default: decode_trace.gif)"
|
||||
] = None
|
||||
frame_ms: Annotated[int, "Per-frame duration in ms"] = 120
|
||||
|
||||
args = tyro.cli(ScriptArgs)
|
||||
|
||||
cfg = GenCfg(
|
||||
tau=args.tau,
|
||||
seed=args.seed,
|
||||
edit_prompt=args.edit_prompt,
|
||||
temperature=args.temperature,
|
||||
verbose=args.verbose,
|
||||
time_independent=args.time_independent,
|
||||
)
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
|
||||
final_text, trace = generate_editflow_minimal(model, tokenizer, args, cfg)
|
||||
print(final_text)
|
||||
|
||||
if args.make_gif:
|
||||
from examples.editflow.viz import render_consecutive_trace_gif
|
||||
|
||||
out = args.gif_path or "decode_trace.gif"
|
||||
path = render_consecutive_trace_gif(
|
||||
trace,
|
||||
tokenizer,
|
||||
out_path=out,
|
||||
frame_ms=args.frame_ms,
|
||||
)
|
||||
print(f"[gif saved] {path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
88
dllm/examples/editflow/llada/adapt.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/llada/adapt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/llada/adapt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/adapt.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/adapt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
lm_head_key: str = "model.transformer.ff_out"
|
||||
init_editflow_from_src: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
# Create EditFlow model (bf16 init on CUDA)
|
||||
ef_cfg = dllm.pipelines.editflow.EditFlowLLaDAConfig.from_pretrained(
|
||||
model_args.model_name_or_path
|
||||
)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16)
|
||||
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
|
||||
if model_args.init_editflow_from_src:
|
||||
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, dtype=torch.bfloat16
|
||||
)
|
||||
dllm.pipelines.editflow.utils.init_editflow_from_src(
|
||||
model, src_model, lm_head_key=model_args.lm_head_key
|
||||
)
|
||||
del src_model
|
||||
model = dllm.utils.load_peft(model, model_args)
|
||||
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
model=model,
|
||||
)
|
||||
67
dllm/examples/editflow/llada/pt.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/llada/pt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/llada/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/pt.py"
|
||||
|
||||
- 24 Nodes, 192 GPUs (FSDP):
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/pt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import pt as editflow_pt
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_pt.ModelArguments):
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
|
||||
lm_head_key: str = "model.transformer.ff_out"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_pt.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_pt.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_pt.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowLLaDAConfig,
|
||||
)
|
||||