1127 update to latest

FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

examples/bert/chat.py

@@ -0,0 +1,71 @@
"""
Interactive chat / generation script for Bert models.
Examples
--------
# Raw multi-turn generation (default)
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
"""
import sys
from dataclasses import dataclass
import transformers
import dllm
from dllm.pipelines import llada
from dllm.tools.chat import multi_turn_chat, single_turn_generate
@dataclass
class ScriptArguments:
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
seed: int = 42
chat: bool = True
visualize: bool = True
def __post_init__(self):
# same base-path resolution logic as in generate.py
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
block_length: int = 32
temperature: float = 0.0
remasking: str = "low_confidence"
def main():
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
if script_args.chat:
multi_turn_chat(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
else:
print("\nSingle-turn generation (no chat template).")
single_turn_generate(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nInterrupted. Bye!")
sys.exit(0)
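
Because HfArgumentParser turns the ScriptArguments and GeneratorConfig fields into CLI flags, the single-turn branch (no chat template) and the generation parameters can be selected from the command line. A minimal usage sketch, assuming the same flag spelling as the docstring example and a placeholder model path; the override values below are illustrative, not defaults from the repo:

# Single-turn generation (no chat template), with smaller decoding budget
python -u examples/bert/chat.py \
    --model_name_or_path "YOUR_MODEL_PATH" \
    --chat False \
    --steps 64 --max_new_tokens 64 --block_length 16 --temperature 0.0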