update to latest

144  dllm/scripts/tests/test_attention_mask.py  Normal file
@@ -0,0 +1,144 @@
"""
|
||||
LLaDA / MoE / Dream / RND attention mask invariance tests (compact version)
|
||||
"""
|
||||
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import dllm
|
||||
import pytest
|
||||
|
||||
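# Shared atol/rtol for all fp32 logit comparisons in this file; chosen loosely
# so benign kernel/padding numerical noise does not trip the assertions.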
ERROR_THRESHOLD = 1e-3


def _cuda_cleanup():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Reclaim inter-process memory blocks (useful after deleting a large model)
        try:
            torch.cuda.ipc_collect()
        except Exception:
            # Not all PyTorch builds expose ipc_collect on all platforms
            pass


def _forward_variants(model):
    """
    Run the 5 padding/mask variants and return logits sliced to the 'real' tokens [1, 2, 3, 4].
    Returns a dict with keys 'A'..'E', each of shape [1, 4, H].
    """
    device = model.device

    # A: no padding
    a_ids = torch.tensor([[1, 2, 3, 4]], device=device)
    a_mask = torch.tensor([[1, 1, 1, 1]], device=device)

    # B: left-pad with a 0
    b_ids = torch.tensor([[0, 1, 2, 3, 4]], device=device)
    b_mask = torch.tensor([[0, 1, 1, 1, 1]], device=device)

    # C: right-pad with a 0
    c_ids = torch.tensor([[1, 2, 3, 4, 0]], device=device)
    c_mask = torch.tensor([[1, 1, 1, 1, 0]], device=device)

    # D: same as A but attention_mask=None
    d_ids = torch.tensor([[1, 2, 3, 4]], device=device)
    d_mask = None

    # E: same as A but omit attention_mask entirely
    e_ids = torch.tensor([[1, 2, 3, 4]], device=device)
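
    # Align every variant to the same 4 real-token positions: B gained a pad on
    # the left, so its first logit is dropped; C gained a pad on the right, so
    # its last logit is dropped. A, D, and E already line up.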
    with torch.no_grad():
        out_A = model(input_ids=a_ids, attention_mask=a_mask).logits  # [1, 4, H]
        out_B = model(input_ids=b_ids, attention_mask=b_mask).logits[:, 1:]  # [1, 4, H]
        out_C = model(input_ids=c_ids, attention_mask=c_mask).logits[:, :-1]  # [1, 4, H]
        out_D = model(input_ids=d_ids, attention_mask=d_mask).logits  # [1, 4, H]
        out_E = model(input_ids=e_ids).logits  # [1, 4, H]

    return {"A": out_A, "B": out_B, "C": out_C, "D": out_D, "E": out_E}


def _assert_invariance(outs: dict, tag: str):
    ref = outs["A"]
    for k in ("B", "C", "D", "E"):
        assert torch.allclose(
            ref, outs[k], atol=ERROR_THRESHOLD, rtol=ERROR_THRESHOLD
        ), f"[{tag}] Mismatch A vs {k}"


@pytest.mark.parametrize(
    "repo, attn_impl, human_name",
    [
        ("GSAI-ML/LLaDA-8B-Base", None, "LLaDA Base"),
        ("inclusionAI/LLaDA-MoE-7B-A1B-Base", None, "LLaDA MoE"),
        ("Dream-org/Dream-v0-Base-7B", None, "Dream Base"),
        ("radicalnumerics/RND1-Base-0910", None, "RND Base (native)"),
        ("radicalnumerics/RND1-Base-0910", "sdpa", "RND Base (SDPA)"),
    ],
)
def test_attention_mask_invariance(repo, attn_impl, human_name):
    """
    For each model/backend:
      1) Check padding/mask invariance across A..E on the 'real' tokens.
      2) Print a ✅ message for debug visibility (pytest still enforces assertions).
    """
    model_path = dllm.utils.resolve_with_base_env(repo, "BASE_MODELS_DIR")

    if attn_impl is None:
        model = transformers.AutoModel.from_pretrained(
            model_path, dtype=torch.float32, device_map="auto"
        ).eval()
    else:
        config = transformers.AutoConfig.from_pretrained(
            model_path, attn_implementation=attn_impl
        )
        model = transformers.AutoModel.from_pretrained(
            model_path, config=config, dtype=torch.float32, device_map="auto"
        ).eval()

    outs = _forward_variants(model)
    _assert_invariance(outs, human_name)

    print(f"✅ {human_name} attention mask invariance passed within {ERROR_THRESHOLD}.")
    del model
    gc.collect()
    _cuda_cleanup()


def test_rnd_native_vs_sdpa_equivalence():
    """
    Verify that RND with native attention and RND with SDPA produce equivalent
    logits on the same real tokens across the A..E variants.
    """
    repo = "radicalnumerics/RND1-Base-0910"
    model_path = dllm.utils.resolve_with_base_env(repo, "BASE_MODELS_DIR")

    # native attention
    model_native = transformers.AutoModel.from_pretrained(
        model_path, dtype=torch.float32, device_map="auto"
    ).eval()

    # SDPA
    config_sdpa = transformers.AutoConfig.from_pretrained(
        model_path, attn_implementation="sdpa"
    )
    model_sdpa = transformers.AutoModel.from_pretrained(
        model_path, config=config_sdpa, dtype=torch.float32, device_map="auto"
    ).eval()

    outs_native = _forward_variants(model_native)
    outs_sdpa = _forward_variants(model_sdpa)

    for k in ("A", "B", "C", "D", "E"):
        assert torch.allclose(
            outs_native[k], outs_sdpa[k], atol=ERROR_THRESHOLD, rtol=ERROR_THRESHOLD
        ), f"[RND cross-backend] native vs SDPA mismatch on {k}"

    print(f"✅ RND native vs SDPA equivalence passed within {ERROR_THRESHOLD}.")
    # Explicitly drop model references
    del model_native
    del model_sdpa
    # Collect Python garbage and release CUDA caches
    gc.collect()
    _cuda_cleanup()

0  dllm/scripts/tests/test_dream_generation.py  Normal file
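
To run just these tests locally, one minimal sketch (not part of this commit) is to invoke pytest from Python; it assumes the repositories above are downloadable or already resolvable via BASE_MODELS_DIR, and the runner file name is hypothetical:

    # run_attention_mask_tests.py -- hypothetical runner, not included in this commit
    import sys

    import pytest

    if __name__ == "__main__":
        # -s keeps the ✅ prints visible; -v lists each parametrized model/backend case
        sys.exit(pytest.main(["-s", "-v", "dllm/scripts/tests/test_attention_mask.py"]))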