1029 add octuple
@@ -96,7 +96,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
num_features = config.nn_params.num_features

# get vocab
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'}
selected_vocab_name = vocab_name[encoding_scheme]

vocab = getattr(vocab_utils, selected_vocab_name)(
@@ -477,7 +477,7 @@ class Evaluator:
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found

midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

@@ -499,7 +499,7 @@ class Evaluator:
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found

midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

@@ -509,7 +509,7 @@ class Evaluator:
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))

prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=num_target_measures)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
@@ -520,7 +520,7 @@ class Evaluator:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

@@ -103,7 +103,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
# self.prediction_order = net.prediction_order
|
||||
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
|
||||
|
||||
self.attribute2idx_after = {'pitch': 0,
|
||||
'duration': 1,
|
||||
'velocity': 2,
|
||||
@@ -113,6 +113,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
'tempo': 6,
|
||||
'instrument': 7}
|
||||
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
|
||||
# if using position attribute, change accordingly
|
||||
if 'position' in self.net.vocab.feature_list:
|
||||
self.attribute2idx_after = {'pitch': 0,
|
||||
'position': 1,
|
||||
'bar': 2,
|
||||
'velocity': 3,
|
||||
'duration': 4,
|
||||
'program': 5,
|
||||
'tempo': 6,
|
||||
'timesig': 7}
|
||||
self.attribute2idx = {'pitch':0, 'position':1, 'bar':2, 'velocity':3, 'duration':4, 'program':5, 'tempo':6, 'timesig':7}
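# attribute2idx gives each attribute's column in the raw condition tensor, while
# attribute2idx_after gives the column expected by the sub-decoder's prediction order;
# the rearrangement loop further below copies condition_step[:, :, idx] into
# condition_step_rearranged[:, :, new_idx] accordingly.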
|
||||
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
|
||||
return self.net(input_seq, target, context=context)
|
||||
|
||||
@@ -161,10 +172,12 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
# measure_bool = (condition[:,1] == 1) # measure tokens
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
elif self.net.vocab.encoding_scheme == 'nb':
|
||||
elif self.net.vocab.encoding_scheme == 'nb' or self.net.vocab.encoding_scheme == 'oct':
|
||||
measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
|
||||
try:
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
except:
|
||||
conditional_input_len = condition.shape[0]
|
||||
if conditional_input_len == 0:
|
||||
conditional_input_len = 50
|
||||
|
||||
@@ -262,7 +275,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
# print(self.attribute2idx)
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
if attr not in attr_list:
|
||||
condition_filtered[:, :, idx] = 126336
|
||||
condition_filtered[:, 1:, idx] = 126336
|
||||
# rearrange condition_filtered to match the prediction order
|
||||
|
||||
cache = LayerIntermediates()
|
||||
@@ -286,8 +299,9 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
new_idx = self.attribute2idx_after[attr]
|
||||
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
|
||||
# print("condition_step shape:", condition_step.shape)
|
||||
# print("condition_step shape:", condition_step)
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
|
||||
# print("sampled_token shape:", sampled_token)
|
||||
else:
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
|
||||
time_end = time.time()
|
||||
|
||||
@@ -713,7 +713,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
dropout:float,
|
||||
sub_decoder_enricher_use:bool,
|
||||
MASK_IDX:int = 126336,
|
||||
denoising_steps:int = 8,
|
||||
denoising_steps:int = 6,
|
||||
eps:float = 1e-3,
|
||||
method:str = 'low-confidence', # or random or auto-regressive
|
||||
):
|
||||
@@ -1029,7 +1029,338 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
# print("toknes", sampled_token_dict)
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
# print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
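# 'low-confidence' remasking: the k most confident predictions among the still-masked
# positions are selected here and committed in the update loop below; the remaining
# sub-tokens stay masked for later denoising steps.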
|
||||
elif self.method == 'random':
|
||||
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
|
||||
elif self.method == 'auto-regressive':
|
||||
indices = torch.tensor([[step]], device=logit.device)
|
||||
# update the masked history
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
|
||||
# stop once all tokens have been unmasked
|
||||
if not expand_masked_history.any():
|
||||
# print("All tokens have been unmasked. Ending denoising process.")
|
||||
break
|
||||
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
|
||||
# get the final sampled tokens by mapping the unmasked embeddings back to token ids
|
||||
sampled_token_dict = {}
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
# print("Final sampled tokens:")
|
||||
# print(sampled_token_dict)
|
||||
# print(condition_step)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
|
||||
# ---- Training ---- #
|
||||
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
|
||||
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
|
||||
# apply layer norm
|
||||
|
||||
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
|
||||
if worst_case: # mask everything, turning this into fully parallel prediction
|
||||
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
|
||||
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# all is embedding
|
||||
# memory_tensor = self.layer_norm(memory_tensor)
|
||||
# apply feature enricher to memory
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# implement sub decoder cross attention
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# get prob
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
|
||||
class DiffusionDecoder(SubDecoderClass):
|
||||
def __init__(
|
||||
self,
|
||||
prediction_order:list,
|
||||
vocab:LangTokenVocab,
|
||||
sub_decoder_depth:int,
|
||||
dim:int,
|
||||
heads:int,
|
||||
dropout:float,
|
||||
sub_decoder_enricher_use:bool,
|
||||
MASK_IDX:int = 126336,
|
||||
denoising_steps:int = 6,
|
||||
eps:float = 1e-3,
|
||||
method:str = 'low-confidence', # or random or auto-regressive
|
||||
):
|
||||
'''
|
||||
The power of cross-attention and the UniAudio-style self-attention lies in using the output of the main decoder (the hidden vec) directly in the sub-decoder.
Because the output of the main decoder is a representation of the whole sequence,
it contains rich enough information to decode the sub-tokens even in a parallel manner.
This is why both architectures that use the main decoder's output directly outperform the original self-attention sub-decoder.
|
||||
'''
|
||||
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
|
||||
self.sub_decoder_enricher_use = sub_decoder_enricher_use
|
||||
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
|
||||
|
||||
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
|
||||
nn.init.zeros_(self.pos_enc.weight)
|
||||
|
||||
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
|
||||
self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of the mask token; its index 126336 is not in the vocab
|
||||
nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
|
||||
self.MASK_idx = MASK_IDX
|
||||
self.denoising_steps = denoising_steps
|
||||
self.eps = eps
|
||||
self.method = method
|
||||
|
||||
self.input_norm = nn.LayerNorm(dim)
|
||||
|
||||
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
|
||||
|
||||
if sub_decoder_enricher_use:
|
||||
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
|
||||
causal_mask = generate_SA_mask(len(prediction_order))
|
||||
causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
|
||||
self.register_buffer('causal_mask', causal_mask)
|
||||
self.register_buffer('causal_ca_mask', causal_ca_mask)
|
||||
|
||||
# get depth of the sub-decoder
|
||||
if sub_decoder_depth > 1:
|
||||
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
|
||||
else:
|
||||
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
|
||||
if sub_decoder_enricher_use:
|
||||
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
|
||||
|
||||
|
||||
# simplified version of the forward process in diffusion model
|
||||
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
|
||||
reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
|
||||
b, l = reshaped_input_ids.shape
|
||||
t = torch.rand(b, device=input_ids.device)
|
||||
p_mask = (1 - eps) * t + eps
|
||||
p_mask = p_mask[:, None].repeat(1, l)
|
||||
|
||||
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
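# In effect this is a linear noise schedule: each row draws t ~ U(0, 1) and masks each
# sub-token independently with probability p_mask = (1 - eps) * t + eps, which runs from
# eps at t=0 up to 1.0 at t=1, so the expected fraction of masked sub-tokens grows
# linearly with t.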
|
||||
# 126336 is used for the [MASK] token; note that this token is not in the vocab
|
||||
if mask_idx is not None:
|
||||
noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
|
||||
else:
|
||||
noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids) # 126336 is used for the [MASK] token
|
||||
return noisy_batch, masked_indices, p_mask
|
||||
|
||||
|
||||
def _apply_window_on_hidden_vec(self, hidden_vec):
|
||||
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
|
||||
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
|
||||
window_size = 1
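# with window_size = 1, the construction below reduces to pairing each timestep's own
# hidden vector with the enricher BOS embedding, giving a (B*T) x 2 x d_model tensor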
|
||||
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
|
||||
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
|
||||
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
|
||||
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
|
||||
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
|
||||
return new_hidden_vec
|
||||
|
||||
def _apply_pos_enc(self, tgt):
|
||||
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
|
||||
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
|
||||
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
|
||||
return tgt_pos
|
||||
|
||||
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
|
||||
for _, feature in enumerate(self.prediction_order[:-1]):
|
||||
feature_idx = self.vocab.feature_list.index(feature)
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
|
||||
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
|
||||
memory_list.append(feature_emb_reshape)
|
||||
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
|
||||
return memory_tensor
|
||||
|
||||
# return a tensor filled with the mask-token embedding
|
||||
def _get_noisy_tensor(self, target_shape):
|
||||
new_target = torch.zeros(target_shape).to(self.device)
|
||||
# fill all the elements in the tensor with the embedding of the mask token
|
||||
new_target[:, :, :] = self.diffusion_mask_emb
|
||||
return new_target
|
||||
|
||||
# prepare the embeddings of the target tokens
|
||||
def _prepare_embedding(self, memory_list, target):
|
||||
for _, feature in enumerate(self.prediction_order):
|
||||
feature_idx = self.vocab.feature_list.index(feature)
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
|
||||
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
|
||||
memory_list.append(feature_emb_reshape)
|
||||
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
|
||||
return memory_tensor
|
||||
|
||||
|
||||
def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
|
||||
memory_list = [] # used for key and value in cross attention
|
||||
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
|
||||
if add_BOS:
|
||||
if target is not None: # training
|
||||
memory_list.append(BOS_emb)
|
||||
else: # inference
|
||||
memory_list.append(BOS_emb[-1:, :, :])
|
||||
else:
|
||||
pass
|
||||
return memory_list
|
||||
|
||||
def _get_num_transfer_tokens(self, mask_index, steps):
|
||||
'''
|
||||
In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
|
||||
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
|
||||
the expected number of tokens transitioned at each step should be consistent.
|
||||
|
||||
This function is designed to precompute the number of tokens that need to be transitioned at each step.
|
||||
'''
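# Illustrative example: with mask_num = 7 masked sub-tokens and steps = 3, base = 2 and
# remainder = 1, so num_transfer_tokens = [[3, 2, 2]]; the remainder is spread over the
# earliest steps and the per-step counts always sum back to mask_num.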
|
||||
mask_num = mask_index.sum(dim=1,keepdim=True)
|
||||
base = mask_num // steps
|
||||
remainder = mask_num % steps
|
||||
|
||||
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
||||
|
||||
for i in range(mask_num.size(0)):
|
||||
num_transfer_tokens[i, :remainder[i]] += 1
|
||||
|
||||
return num_transfer_tokens
|
||||
|
||||
def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False,step=None):
|
||||
sampled_token_dict = {}
|
||||
logits_dict = {}
|
||||
candidate_token_embeddings = {}
|
||||
candidate_token_probs = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
# print("*"*8)
|
||||
logits_list = []
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_list.append(logit)
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
logit = logits_list[idx] # B x T x vocab_size
|
||||
sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
|
||||
if step==0 and force_decode:
|
||||
if feature == 'velocity':
|
||||
sampled_token = torch.tensor([2]).to(logit.device)
|
||||
prob = torch.tensor([1.0]).to(logit.device)
|
||||
else:
|
||||
prob = torch.tensor([0.0]).to(logit.device)
|
||||
# print(feature, sampled_token, prob)
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
logits_dict[feature] = logit
|
||||
candidate_token_probs[feature] = prob
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
|
||||
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
|
||||
candidate_token_embeddings[feature] = feature_emb_reshape
|
||||
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens
|
||||
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
|
||||
# print("sampled_token_dict", sampled_token_dict)
|
||||
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
|
||||
|
||||
# apply window on hidden_vec for enricher
|
||||
if self.sub_decoder_enricher_use:
|
||||
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
|
||||
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
|
||||
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
if bos_hidden_vec is None: # start of generation
|
||||
if target is None:
|
||||
bos_hidden_vec = input_seq_pos
|
||||
else:
|
||||
bos_hidden_vec = hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
|
||||
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
|
||||
|
||||
else:
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# input_seq_pos = input_seq
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
|
||||
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
# prepare memory
|
||||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||||
# ---- Generate(Inference) ---- #
|
||||
if target is None:
|
||||
sampled_token_dict = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
l = len(self.prediction_order) # num_sub_tokens
|
||||
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
|
||||
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
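# Inference starts from a fully masked state: every sub-token slot initially holds the
# learned mask embedding, and masked_history below marks all positions as still masked.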
|
||||
|
||||
# indicates the positions of the mask tokens; True means the token is still masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
# add attribute control here
|
||||
stored_logits_dict = {}
|
||||
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
|
||||
if condition_step is not None:
|
||||
# print("shape of condition_step", condition_step.shape)
|
||||
condition_step = condition_step.reshape((b*t, l))
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
token = condition_step[i][j]
|
||||
if condition_step[i][j] != self.MASK_idx:
|
||||
|
||||
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
|
||||
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
|
||||
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising loop
|
||||
# with torch.profiler.profile(
|
||||
# activities=[
|
||||
# torch.profiler.ProfilerActivity.CPU,
|
||||
# torch.profiler.ProfilerActivity.CUDA],
|
||||
# record_shapes=True,
|
||||
# profile_memory=True,
|
||||
# with_stack=True
|
||||
# ) as prof:
|
||||
for step in range(self.denoising_steps):
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
|
||||
force_decode=Force_decode,
|
||||
step=step)
|
||||
|
||||
# print("step", step)
|
||||
# print("toknes", sampled_token_dict)
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
# print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
|
||||
|
||||
@@ -22,6 +22,8 @@ class Augmentor:
|
||||
self.chord_idx = self.feature_list.index('chord')
|
||||
|
||||
def _get_shift(self, segment):
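# For the Octuple ('oct') encoding the segment is not transposed: a shift of 0 is
# returned below, matching the dataset code, which also skips augmentation for 'oct'.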
|
||||
if self.encoding_scheme == 'oct':
|
||||
return 0
|
||||
# the pitch vocab has ignore token in 0 index
|
||||
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
|
||||
pitch_mask = segment != 0
|
||||
|
||||
@@ -73,7 +73,7 @@ class VanillaTransformer_compiler():
|
||||
for i in range(len(self.data_list)):
|
||||
tune_in_idx, tune_name = self.data_list[i]
|
||||
tune_in_idx = torch.LongTensor(tune_in_idx)
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
else:
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
@@ -148,7 +148,7 @@ class VanillaTransformer_compiler():
|
||||
for i in range(len(self.data_list)):
|
||||
tune_in_idx, tune_name = self.data_list[i]
|
||||
tune_in_idx = torch.LongTensor(tune_in_idx)
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
else:
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
|
||||
@@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
|
||||
print(f"Error encoding caption for tune {tune_name}: {e}")
|
||||
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
return segment, tensor_mask, tune_name, encoded_caption
|
||||
if self.data_type == 'train':
|
||||
augmented_segment = self.augmentor(segment)
|
||||
return augmented_segment, tensor_mask, tune_name, encoded_caption
|
||||
else:
|
||||
return segment, tensor_mask, tune_name, encoded_caption
|
||||
|
||||
def get_segments_with_tune_idx(self, tune_name, seg_order):
|
||||
'''
|
||||
@@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
|
||||
self.data_type = data_type
|
||||
self.augmentor = augmentor
|
||||
self.eos_token = vocab.eos_token
|
||||
self.vocab = vocab
|
||||
self.compile_function = VanillaTransformer_compiler(
|
||||
data_list=self.data_list,
|
||||
augmentor=self.augmentor,
|
||||
@@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
|
||||
encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
except Exception as e:
|
||||
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
if self.data_type == 'train':
|
||||
if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
|
||||
segment = self.augmentor(segment)
|
||||
# replace tune_name with the caption's input_ids
|
||||
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import re
|
||||
import os, sys
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from music21 import converter
|
||||
import muspy
|
||||
import miditoolkit
|
||||
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
|
||||
from symusic import Score
|
||||
from miditok import Octuple, TokenizerConfig
|
||||
|
||||
from .midi2audio import FluidSynth
|
||||
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
|
||||
@@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
|
||||
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||||
midi_obj.dump(output_path)
|
||||
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
|
||||
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||||
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||||
return midi_obj
|
||||
|
||||
class MidiDecoder4Octuple(MidiDecoder4REMI):
|
||||
def __init__(self, vocab, in_beat_resolution, dataset_name):
|
||||
super().__init__(vocab, in_beat_resolution, dataset_name)
|
||||
|
||||
|
||||
|
||||
def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
输入:
|
||||
t: torch.Tensor, 形状 (1, N, M)
|
||||
功能:
|
||||
删除包含独立元素 0, 1, 2, 3 的子tensor行
|
||||
返回:
|
||||
torch.Tensor, 同样保持 batch 维度 (1, N_filtered, M)
|
||||
"""
|
||||
if t.dim() != 3:
|
||||
raise ValueError("输入 tensor 必须是三维 (batch, seq_len, feature)")
|
||||
|
||||
# 构造一个 mask,True 表示该行不包含 0,1,2,3
|
||||
exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
|
||||
|
||||
# 判断每一行是否含有这些值
|
||||
mask = ~((t[0][..., None] == exclude_vals).any(dim=(1, 2)))
|
||||
|
||||
# 过滤行并保留 batch 维
|
||||
filtered_tensor = t[0][mask].unsqueeze(0)
|
||||
|
||||
return filtered_tensor
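# Illustrative example (hypothetical shapes): if t has shape (1, 4, 8) and exactly one row
# contains the value 2, the result has shape (1, 3, 8). The ids 0-3 are assumed here to be
# special tokens (e.g. padding/SOS/EOS) that the Octuple tokenizer cannot decode, which is
# why such rows are dropped before tokenizer.decode is called below.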
|
||||
|
||||
def __call__(self, generated_output, output_path=None):
|
||||
config = TokenizerConfig(
|
||||
use_time_signatures=True,
|
||||
use_tempos=True,
|
||||
use_velocities=True,
|
||||
use_programs=True,
|
||||
remove_duplicated_notes=True,
|
||||
delete_equal_successive_tempo_changes=True,
|
||||
)
|
||||
config.additional_params["max_bar_embedding"] = 512
|
||||
tokenizer = Octuple(config)
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# generated_output = generated_output[:, 1:generated_output.shape[1]-1, :] # remove sos token
|
||||
generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
|
||||
print(output_path)
|
||||
try:
|
||||
tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
|
||||
tok_seq.dump_midi(output_path)
|
||||
except Exception as e:
|
||||
print(generated_output)
|
||||
print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
|
||||
tok_seq = None
|
||||
|
||||
return tok_seq
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
defaults:
|
||||
# - nn_params: nb8_embSum_NMT
|
||||
# - nn_params: remi8
|
||||
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
|
||||
- nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
|
||||
- nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
|
||||
# - nn_params: nb8_embSum_subPararell
|
||||
# - nn_params: nb8_embSum_diff_t2m_150M
|
||||
@@ -15,7 +15,7 @@ defaults:
|
||||
# - nn_params: remi8_main12_head_16_dim512
|
||||
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
|
||||
|
||||
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
dataset: Melody # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
captions_path: dataset/midicaps/train_set.json
|
||||
|
||||
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
|
||||
@@ -31,7 +31,7 @@ tau: 0.5
|
||||
|
||||
train_params:
|
||||
device: cuda
|
||||
batch_size: 5
|
||||
batch_size: 10
|
||||
grad_clip: 1.0
|
||||
num_iter: 300000 # total number of iterations
|
||||
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
|
||||
@@ -44,7 +44,7 @@ train_params:
|
||||
focal_gamma: 0
|
||||
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
|
||||
scheduler : cosinelr
|
||||
initial_lr: 0.0003
|
||||
initial_lr: 0.00001
|
||||
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
|
||||
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
|
||||
warmup_steps: 2000 #number of warmup steps
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
encoding_scheme: oct
|
||||
num_features: 8
|
||||
vocab_name: MusicTokenVocabOct
|
||||
model_name: AmadeusModel
|
||||
input_embedder_name: SummationEmbedder
|
||||
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||
sub_decoder_name: DiffusionDecoder
|
||||
model_dropout: 0.2
|
||||
input_embedder:
|
||||
num_layer: 1
|
||||
num_head: 8
|
||||
main_decoder:
|
||||
dim_model: 768
|
||||
num_layer: 16
|
||||
num_head: 12
|
||||
sub_decoder:
|
||||
decout_window_size: 1 # 1 means no previous decoding output added
|
||||
num_layer: 1
|
||||
feature_enricher_use: False
|
||||
@@ -29,8 +29,12 @@ def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_pa
|
||||
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
|
||||
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
|
||||
}
|
||||
|
||||
if encoding_scheme == 'remi':
|
||||
oct_prediction_order = {
|
||||
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
|
||||
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
|
||||
if encoding_scheme == 'oct':
|
||||
prediction_order = oct_prediction_order[num_features]
|
||||
elif encoding_scheme == 'remi':
|
||||
prediction_order = feature_prediction_order_dict[num_features]
|
||||
elif encoding_scheme == 'cp':
|
||||
if nn_params.get("partial_sequential_prediction", False):
|
||||
@@ -239,11 +243,11 @@ class DiffusionLoss4CompoundToken():
|
||||
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
|
||||
train_loss_list.append(training_loss)
|
||||
if valid:
|
||||
if key == 'type':
|
||||
if key == 'type' or key == 'timesig':
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
elif key == 'beat':
|
||||
elif key == 'beat' or key == 'position' or key == 'bar':
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument':
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
else:
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
|
||||
@@ -74,7 +74,7 @@ class LanguageModelTrainer:
|
||||
sampling_threshold: float, # Threshold for sampling decisions
|
||||
sampling_temperature: float, # Temperature for controlling sampling randomness
|
||||
config, # Configuration parameters (contains general, training, and inference settings)
|
||||
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
# model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
|
||||
):
|
||||
# Save model, optimizer, and other configurations
|
||||
@@ -104,7 +104,6 @@ class LanguageModelTrainer:
|
||||
checkpoint = torch.load(model_checkpoint, map_location='cpu')
|
||||
# print state dict keys
|
||||
print("Loading model checkpoint from", model_checkpoint)
|
||||
print("Checkpoint keys:", checkpoint['model'].keys())
|
||||
if isinstance(self.model, DDP):
|
||||
self.model.module.load_state_dict(checkpoint['model'], strict=False)
|
||||
else:
|
||||
@@ -902,9 +901,9 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
|
||||
correct_guess_by_feature = defaultdict(int)
|
||||
num_tokens_by_feature = defaultdict(int)
|
||||
for idx, key in enumerate(self.vocab.feature_list):
|
||||
if key == 'type':
|
||||
if key == 'type' or key == 'timesig':
|
||||
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument':
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
|
||||
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
|
||||
elif key == 'beat':
|
||||
# NB's beat vocab has Ignore and CONTI token