1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -96,7 +96,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
num_features = config.nn_params.num_features
# get vocab
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'}
selected_vocab_name = vocab_name[encoding_scheme]
vocab = getattr(vocab_utils, selected_vocab_name)(
@ -477,7 +477,7 @@ class Evaluator:
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
@ -499,7 +499,7 @@ class Evaluator:
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
@ -509,7 +509,7 @@ class Evaluator:
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=num_target_measures)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
@ -520,7 +520,7 @@ class Evaluator:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
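The hunks above all extend the same string-based dispatch: a dict maps the encoding-scheme key to a class name, and getattr resolves that name on a module (vocab_utils for vocabularies, decoding_utils for MIDI decoders). A minimal, self-contained sketch of the pattern (hypothetical call site; assumes a top-level module named decoding_utils exposing these classes):

import importlib

def build_midi_decoder(encoding_scheme, vocab, in_beat_resolution, dataset_name):
    # map scheme name -> decoder class name, then resolve it on the module
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP',
                         'nb': 'MidiDecoder4NB', 'oct': 'MidiDecoder4Octuple'}
    decoding_utils = importlib.import_module('decoding_utils')
    decoder_cls = getattr(decoding_utils, midi_decoder_dict[encoding_scheme])
    return decoder_cls(vocab=vocab, in_beat_resolution=in_beat_resolution,
                       dataset_name=dataset_name)

An unknown scheme fails loudly with a KeyError at the dict lookup rather than silently falling back to a default decoder.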

View File

@ -103,7 +103,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
super().__init__()
self.net = net
# self.prediction_order = net.prediction_order
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
self.attribute2idx_after = {'pitch': 0,
'duration': 1,
'velocity': 2,
@ -113,6 +113,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
'tempo': 6,
'instrument': 7}
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
# if using position attribute, change accordingly
if 'position' in self.net.vocab.feature_list:
self.attribute2idx_after = {'pitch': 0,
'position': 1,
'bar': 2,
'velocity': 3,
'duration': 4,
'program': 5,
'tempo': 6,
'timesig': 7}
self.attribute2idx = {'pitch':0, 'position':1, 'bar':2, 'velocity':3, 'duration':4, 'program':5, 'tempo':6, 'timesig':7}
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
return self.net(input_seq, target, context=context)
@ -161,10 +172,12 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
# measure_bool = (condition[:,1] == 1) # measure tokens
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
elif self.net.vocab.encoding_scheme == 'nb':
elif self.net.vocab.encoding_scheme == 'nb' or self.net.vocab.encoding_scheme == 'oct':
measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
try:
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
except IndexError: # fewer measures than requested; fall back to the full condition length
conditional_input_len = condition.shape[0]
if conditional_input_len == 0:
conditional_input_len = 50
@ -262,7 +275,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
# print(self.attribute2idx)
for attr, idx in self.attribute2idx.items():
if attr not in attr_list:
condition_filtered[:, :, idx] = 126336
condition_filtered[:, 1:, idx] = 126336
# rearrange condition_filtered to match prediction order
cache = LayerIntermediates()
@ -286,8 +299,9 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
for attr, idx in self.attribute2idx.items():
new_idx = self.attribute2idx_after[attr]
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
# print("condition_step shape:", condition_step.shape)
# print("condition_step shape:", condition_step)
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
# print("sampled_token shape:", sampled_token)
else:
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()
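The two index maps added above drive a per-step reordering of the condition tensor before it reaches the sub-decoder. A toy sketch of that permutation, using the maps from the hunk (the elided middle entries of attribute2idx_after, type/beat/chord at 3/4/5, are inferred from the visible indices; shapes are toy values):

import torch

attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3,
                 'instrument': 4, 'pitch': 5, 'duration': 6, 'velocity': 7}
attribute2idx_after = {'pitch': 0, 'duration': 1, 'velocity': 2, 'type': 3,
                       'beat': 4, 'chord': 5, 'tempo': 6, 'instrument': 7}

condition_step = torch.randint(0, 100, (1, 4, 8))  # B x T x num_features
rearranged = torch.empty_like(condition_step)
for attr, src in attribute2idx.items():
    rearranged[:, :, attribute2idx_after[attr]] = condition_step[:, :, src]
assert torch.equal(rearranged[:, :, 0], condition_step[:, :, 5])  # pitch column moved to the front

For the 'oct' scheme the two maps are identical, so the rearrangement is a no-op; the machinery exists so that all encoding schemes share one code path.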

View File

@ -713,7 +713,7 @@ class DiffusionDecoder(SubDecoderClass):
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 8,
denoising_steps:int = 6,
eps:float = 1e-3,
method:str = 'low-confidence', # or 'random' or 'auto-regressive'
):
@ -1029,7 +1029,338 @@ class DiffusionDecoder(SubDecoderClass):
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
print("stacked_logits_probs", stacked_logits_probs.clone())
# print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
elif self.method == 'random':
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
elif self.method == 'auto-regressive':
indices = torch.tensor([[step]], device=logit.device)
# update the masked history
for i in range(b*t):
for j in range(l):
if j in indices[i]:
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
# stop early once all tokens have been unmasked
if not expand_masked_history.any():
# print("All tokens have been unmasked. Ending denoising process.")
break
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
# get final sampled tokens by embedding the unmasked tokens
sampled_token_dict = {}
for idx, feature in enumerate(self.prediction_order):
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
sampled_token_dict[feature] = sampled_token
# print("Final sampled tokens:")
# print(sampled_token_dict)
# print(condition_step)
return stored_logits_dict, sampled_token_dict
# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask everything, reducing to fully parallel prediction
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# memory_tensor now holds embeddings only
# memory_tensor = self.layer_norm(memory_tensor)
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
class DiffusionDecoder(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 6,
eps:float = 1e-3,
method:str = 'low-confidence', # or 'random' or 'auto-regressive'
):
'''
The power of cross-attention and UniAudio-style self-attention lies in feeding the output of the main decoder (the hidden vec) directly into the sub-decoder.
Because the output of the main decoder represents the whole sequence,
it carries rich enough information to decode the sub-tokens even in parallel.
Hence both architectures that use the main decoder's output directly outperform the original self-attention sub-decoder.
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.sub_decoder_enricher_use = sub_decoder_enricher_use
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of the mask token; its index is 126336, which is not in the vocab
nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
self.MASK_idx = MASK_IDX
self.denoising_steps = denoising_steps
self.eps = eps
self.method = method
self.input_norm = nn.LayerNorm(dim)
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
causal_mask = generate_SA_mask(len(prediction_order))
causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causal_ca_mask)
# get depth of the sub-decoder
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
else:
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
# simplified version of the forward (noising) process in the diffusion model
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
b, l = reshaped_input_ids.shape
t = torch.rand(b, device=input_ids.device)
p_mask = (1 - eps) * t + eps
p_mask = p_mask[:, None].repeat(1, l)
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
# 126336 is used as the [MASK] token; note that this token is not in the vocab
if mask_idx is not None:
noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
else:
noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids) # hard-coded [MASK] fallback when no mask_idx is given
return noisy_batch, masked_indices, p_mask
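# worked example (a sketch, not part of the commit): with eps=1e-3 and a
# sampled t=0.5, p_mask = 0.999*0.5 + 0.001 = 0.5005 for every sub-token, so
# about half of the sub-token slots are replaced by the out-of-vocab MASK index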
def _apply_window_on_hidden_vec(self, hidden_vec):
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
window_size = 1
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
return new_hidden_vec
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
return tgt_pos
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
for _, feature in enumerate(self.prediction_order[:-1]):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
return memory_tensor
# return a tensor filled entirely with the mask-token embedding
def _get_noisy_tensor(self, target_shape):
new_target = torch.zeros(target_shape).to(self.device)
# fill all the elements in the tensor with the embedding of the mask token
new_target[:, :, :] = self.diffusion_mask_emb
return new_target
# prepare the embedding of the target
def _prepare_embedding(self, memory_list, target):
for _, feature in enumerate(self.prediction_order):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
return memory_tensor
def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
memory_list = [] # used for key and value in cross attention
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
if add_BOS:
if target is not None: # training
memory_list.append(BOS_emb)
else: # inference
memory_list.append(BOS_emb[-1:, :, :])
else:
pass
return memory_list
def _get_num_transfer_tokens(self, mask_index, steps):
'''
In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
the expected number of tokens transitioned at each step is the same.
This function precomputes the number of tokens that need to be transitioned at each step.
'''
mask_num = mask_index.sum(dim=1,keepdim=True)
base = mask_num // steps
remainder = mask_num % steps
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
for i in range(mask_num.size(0)):
num_transfer_tokens[i, :remainder[i]] += 1
return num_transfer_tokens
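# worked example (a sketch, not part of the commit): 8 masked sub-tokens with
# steps=6 give base=1 and remainder=2, so num_transfer_tokens = [2,2,1,1,1,1];
# the remainder is front-loaded onto the earliest denoising steps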
def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False,step=None):
sampled_token_dict = {}
logits_dict = {}
candidate_token_embeddings = {}
candidate_token_probs = {}
b,t,d = hidden_vec.shape # B x T x d_model
# print("*"*8)
logits_list = []
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_list.append(logit)
for idx, feature in enumerate(self.prediction_order):
logit = logits_list[idx] # B x T x vocab_size
sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if step==0 and force_decode:
if feature == 'velocity':
sampled_token = torch.tensor([2]).to(logit.device)
prob = torch.tensor([1.0]).to(logit.device)
else:
prob = torch.tensor([0.0]).to(logit.device)
# print(feature, sampled_token, prob)
sampled_token_dict[feature] = sampled_token
logits_dict[feature] = logit
candidate_token_probs[feature] = prob
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
# print("sampled_token_dict", sampled_token_dict)
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
if bos_hidden_vec is None: # start of generation
if target is None:
bos_hidden_vec = input_seq_pos
else:
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
else:
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = input_seq
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
# indicates the positions of mask tokens; 1 (True) means the token is still masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
# add attribute control here
stored_logits_dict = {}
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
if condition_step is not None:
# print("shape of condition_step", condition_step.shape)
condition_step = condition_step.reshape((b*t, l))
for i in range(b*t):
for j in range(l):
token = condition_step[i][j]
if condition_step[i][j] != self.MASK_idx:
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising loop
# with torch.profiler.profile(
# activities=[
# torch.profiler.ProfilerActivity.CPU,
# torch.profiler.ProfilerActivity.CUDA],
# record_shapes=True,
# profile_memory=True,
# with_stack=True
# ) as prof:
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
force_decode=Force_decode,
step=step)
# print("step", step)
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
# print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)

View File

@ -22,6 +22,8 @@ class Augmentor:
self.chord_idx = self.feature_list.index('chord')
def _get_shift(self, segment):
if self.encoding_scheme == 'oct':
return 0
# the pitch vocab has ignore token in 0 index
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
pitch_mask = segment != 0

View File

@ -73,7 +73,7 @@ class VanillaTransformer_compiler():
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
eos_token = torch.LongTensor(self.eos_token)
else:
eos_token = torch.LongTensor(self.eos_token)
@ -148,7 +148,7 @@ class VanillaTransformer_compiler():
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
eos_token = torch.LongTensor(self.eos_token)
else:
eos_token = torch.LongTensor(self.eos_token)

View File

@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
print(f"Error encoding caption for tune {tune_name}: {e}")
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
return segment, tensor_mask, tune_name, encoded_caption
if self.data_type == 'train':
augmented_segment = self.augmentor(segment)
return augmented_segment, tensor_mask, tune_name, encoded_caption
else:
return segment, tensor_mask, tune_name, encoded_caption
def get_segments_with_tune_idx(self, tune_name, seg_order):
'''
@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
self.data_type = data_type
self.augmentor = augmentor
self.eos_token = vocab.eos_token
self.vocab = vocab
self.compile_function = VanillaTransformer_compiler(
data_list=self.data_list,
augmentor=self.augmentor,
@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
except Exception as e:
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
if self.data_type == 'train':
if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
segment = self.augmentor(segment)
# use input_ids in place of tune_name
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption

View File

@ -1,13 +1,17 @@
import re
import os, sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from symusic import Score
from miditok import Octuple, TokenizerConfig
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
class MidiDecoder4Octuple(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
"""
输入:
t: torch.Tensor, 形状 (1, N, M)
功能:
删除包含独立元素 0, 1, 2, 3 的子tensor行
返回:
torch.Tensor, 同样保持 batch 维度 (1, N_filtered, M)
"""
if t.dim() != 3:
raise ValueError("输入 tensor 必须是三维 (batch, seq_len, feature)")
# 构造一个 maskTrue 表示该行不包含 0,1,2,3
exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
# 判断每一行是否含有这些值
mask = ~((t[0][..., None] == exclude_vals).any(dim=(1, 2)))
# 过滤行并保留 batch 维
filtered_tensor = t[0][mask].unsqueeze(0)
return filtered_tensor
def __call__(self, generated_output, output_path=None):
config = TokenizerConfig(
use_time_signatures=True,
use_tempos=True,
use_velocities=True,
use_programs=True,
remove_duplicated_notes=True,
delete_equal_successive_tempo_changes=True,
)
config.additional_params["max_bar_embedding"] = 512
tokenizer = Octuple(config)
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# generated_output = generated_output[:, 1:generated_output.shape[1]-1, :] # remove sos token
generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
print(output_path)
try:
tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
tok_seq.dump_midi(output_path)
except Exception as e:
print(generated_output)
print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
tok_seq = None
return tok_seq
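MidiDecoder4Octuple delegates the actual token-to-MIDI conversion to miditok, as the decode/dump_midi calls above show. A hedged round-trip sketch with the same tokenizer settings (assumes miditok 3.x, where decode() returns a symusic Score exposing dump_midi(); the file names are placeholders):

from miditok import Octuple, TokenizerConfig

config = TokenizerConfig(use_time_signatures=True, use_tempos=True,
                         use_velocities=True, use_programs=True)
config.additional_params["max_bar_embedding"] = 512
tokenizer = Octuple(config)

tok_seq = tokenizer("input.mid")    # encode: MIDI file -> Octuple token sequence
score = tokenizer.decode(tok_seq)   # decode: tokens -> symusic Score
score.dump_midi("roundtrip.mid")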

View File

@ -1,8 +1,8 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
- nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
- nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M
@ -15,7 +15,7 @@ defaults:
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
dataset: Melody # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@ -31,7 +31,7 @@ tau: 0.5
train_params:
device: cuda
batch_size: 5
batch_size: 10
grad_clip: 1.0
num_iter: 300000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
@ -44,7 +44,7 @@ train_params:
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.0003
initial_lr: 0.00001
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 #number of warmup steps

View File

@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False
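A quick way to sanity-check this new nn_params file is to load it directly with OmegaConf (the path below is hypothetical; at runtime Hydra composes it through the defaults list shown earlier and exposes it as config.nn_params):

from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml")
assert cfg.encoding_scheme == "oct" and cfg.num_features == 8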

View File

@ -29,8 +29,12 @@ def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_pa
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
}
if encoding_scheme == 'remi':
oct_prediction_order = {
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
if encoding_scheme == 'oct':
prediction_order = oct_prediction_order[num_features]
elif encoding_scheme == 'remi':
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'cp':
if nn_params.get("partial_sequential_prediction", False):
@ -239,11 +243,11 @@ class DiffusionLoss4CompoundToken():
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
train_loss_list.append(training_loss)
if valid:
if key == 'type':
if key == 'type' or key == 'timesig':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'beat':
elif key == 'beat' or key == 'position' or key == 'bar':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'chord' or key == 'tempo' or key == 'instrument':
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
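The growing elif chain above routes each feature to its ignore/CONTI token pair for logging. The same routing can be read off a lookup table; a hypothetical, behavior-equivalent refactor (not part of the commit):

def logging_loss_args(key):
    # per-feature ignore/CONTI routing, mirrored from the elif chain above
    special = {
        'type':       (None, None),
        'timesig':    (None, None),
        'beat':       (0, 9999),
        'position':   (0, 9999),
        'bar':        (0, 9999),
        'chord':      (0, 9999),
        'tempo':      (0, 9999),
        'instrument': (0, 9999),
        'program':    (0, 9999),
    }
    ignore_token, conti_token = special.get(key, (0, None))  # default: all other features
    return dict(ignore_token=ignore_token, conti_token=conti_token)

assert logging_loss_args('timesig') == dict(ignore_token=None, conti_token=None)
assert logging_loss_args('pitch') == dict(ignore_token=0, conti_token=None)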

View File

@ -74,7 +74,7 @@ class LanguageModelTrainer:
sampling_threshold: float, # Threshold for sampling decisions
sampling_temperature: float, # Temperature for controlling sampling randomness
config, # Configuration parameters (contains general, training, and inference settings)
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
# model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
):
# Save model, optimizer, and other configurations
@ -104,7 +104,6 @@ class LanguageModelTrainer:
checkpoint = torch.load(model_checkpoint, map_location='cpu')
# print state dict keys
print("Loading model checkpoint from", model_checkpoint)
print("Checkpoint keys:", checkpoint['model'].keys())
if isinstance(self.model, DDP):
self.model.module.load_state_dict(checkpoint['model'], strict=False)
else:
@ -902,9 +901,9 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
correct_guess_by_feature = defaultdict(int)
num_tokens_by_feature = defaultdict(int)
for idx, key in enumerate(self.vocab.feature_list):
if key == 'type':
if key == 'type' or key == 'timesig':
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
elif key == 'chord' or key == 'tempo' or key == 'instrument':
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
elif key == 'beat':
# NB's beat vocab has Ignore and CONTI token