1127 update to latest
This commit is contained in:
71
dllm/examples/bert/chat.py
Normal file
71
dllm/examples/bert/chat.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Interactive chat / generation script for Bert models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
# Raw multi-turn generation (default)
|
||||
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import llada
|
||||
from dllm.tools.chat import multi_turn_chat, single_turn_generate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
|
||||
seed: int = 42
|
||||
chat: bool = True
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# same base-path resolution logic as in generate.py
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 32
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
def main():
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
|
||||
if script_args.chat:
|
||||
multi_turn_chat(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
else:
|
||||
print("\nSingle-turn generation (no chat template).")
|
||||
single_turn_generate(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted. Bye!")
|
||||
sys.exit(0)
|
||||
Reference in New Issue
Block a user