1029 add octuple

Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -96,7 +96,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
num_features = config.nn_params.num_features
# get vocab
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'}
selected_vocab_name = vocab_name[encoding_scheme]
vocab = getattr(vocab_utils, selected_vocab_name)(
@ -477,7 +477,7 @@ class Evaluator:
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
@ -499,7 +499,7 @@ class Evaluator:
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
@ -509,7 +509,7 @@ class Evaluator:
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=num_target_measures)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
@ -520,7 +520,7 @@ class Evaluator:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

View File

@ -103,7 +103,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
super().__init__()
self.net = net
# self.prediction_order = net.prediction_order
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
self.attribute2idx_after = {'pitch': 0,
'duration': 1,
'velocity': 2,
@ -113,6 +113,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
'tempo': 6,
'instrument': 7}
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
# if the vocab uses the position attribute (Octuple encoding), switch to the Octuple index ordering
if 'position' in self.net.vocab.feature_list:
self.attribute2idx_after = {'pitch': 0,
'position': 1,
'bar': 2,
'velocity': 3,
'duration': 4,
'program': 5,
'tempo': 6,
'timesig': 7}
self.attribute2idx = {'pitch':0, 'position':1, 'bar':2, 'velocity':3, 'duration':4, 'program':5, 'tempo':6, 'timesig':7}
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
return self.net(input_seq, target, context=context)
@ -161,10 +172,12 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
# measure_bool = (condition[:,1] == 1) # measure tokens
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
elif self.net.vocab.encoding_scheme == 'nb':
elif self.net.vocab.encoding_scheme == 'nb' or self.net.vocab.encoding_scheme == 'oct':
measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
try:
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
except:
conditional_input_len = condition.shape[0]
if conditional_input_len == 0:
conditional_input_len = 50
@ -262,7 +275,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
# print(self.attribute2idx)
for attr, idx in self.attribute2idx.items():
if attr not in attr_list:
condition_filtered[:, :, idx] = 126336
condition_filtered[:, 1:, idx] = 126336
# rearrange condition_filtered to match the prediction order
cache = LayerIntermediates()
@ -286,8 +299,9 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
for attr, idx in self.attribute2idx.items():
new_idx = self.attribute2idx_after[attr]
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
# print("condition_step shape:", condition_step.shape)
# print("condition_step shape:", condition_step)
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
# print("sampled_token shape:", sampled_token)
else:
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()

View File

@ -713,7 +713,7 @@ class DiffusionDecoder(SubDecoderClass):
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 8,
denoising_steps:int = 6,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
@ -1029,7 +1029,338 @@ class DiffusionDecoder(SubDecoderClass):
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
print("stacked_logits_probs", stacked_logits_probs.clone())
# print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
elif self.method == 'random':
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
elif self.method == 'auto-regressive':
indices = torch.tensor([[step]], device=logit.device)
# update the masked history
for i in range(b*t):
for j in range(l):
if j in indices[i]:
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
# skip if all tokens are unmasked
if not expand_masked_history.any():
# print("All tokens have been unmasked. Ending denoising process.")
break
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
# get final sampled tokens by embedding the unmasked tokens
sampled_token_dict = {}
for idx, feature in enumerate(self.prediction_order):
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
sampled_token_dict[feature] = sampled_token
# print("Final sampled tokens:")
# print(sampled_token_dict)
# print(condition_step)
return stored_logits_dict, sampled_token_dict
# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask everything, turning the prediction fully parallel
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# all is embedding
# memory_tensor = self.layer_norm(memory_tensor)
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
class DiffusionDecoder(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 6,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
'''
The power of the cross-attention and UniAudio-style self-attention sub-decoders lies in using the output of the main decoder (the hidden vec) directly in the sub-decoder.
As the output of the main decoder is a representation of the whole sequence,
it contains rich enough information to decode the sub-tokens even in parallel.
So both architectures that use the main decoder output directly outperform the original self-attention sub-decoder.
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.sub_decoder_enricher_use = sub_decoder_enricher_use
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of the mask token; its idx is 126336, which is not in the vocab
nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
self.MASK_idx = MASK_IDX
self.denoising_steps = denoising_steps
self.eps = eps
self.method = method
self.input_norm = nn.LayerNorm(dim)
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
causal_mask = generate_SA_mask(len(prediction_order))
causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causal_ca_mask)
# get depth of the sub-decoder
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
else:
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
# simplified version of the forward process in diffusion model
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
b, l = reshaped_input_ids.shape
t = torch.rand(b, device=input_ids.device)
p_mask = (1 - eps) * t + eps
p_mask = p_mask[:, None].repeat(1, l)
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
# 126336 is used as the [MASK] token id; note that this id is not in the vocab
if mask_idx is not None:
noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
else:
noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids) # fall back to 126336 as the [MASK] token id
return noisy_batch, masked_indices, p_mask
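# Illustrative sketch, not part of this commit: what the forward process does
# to a toy 1 x 4 row of sub-token ids. A per-row masking ratio p_mask is drawn,
# the selected positions are replaced with the 126336 [MASK] id, and
# masked_indices records exactly those positions so the loss can later be
# reweighted by p_mask. One possible draw:
#   reshaped_input_ids : [[60, 12, 3, 8]]
#   noisy_batch        : [[60, 126336, 3, 126336]]
#   masked_indices     : [[False, True, False, True]]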
def _apply_window_on_hidden_vec(self, hidden_vec):
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
window_size = 1
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
return new_hidden_vec
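# Illustrative note, not part of this commit: with window_size = 1 and
# hidden_vec of shape (B, T, d), the unfold/reshape above yields a
# (B*T, 2, d) tensor, i.e. the learned enricher BOS embedding followed by the
# hidden state of the current step, which is what the feature enricher attends to.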
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
return tgt_pos
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
for _, feature in enumerate(self.prediction_order[:-1]):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
return memory_tensor
# return a tensor filled with the mask-token embedding
def _get_noisy_tensor(self, target_shape):
new_target = torch.zeros(target_shape).to(self.device)
# fill all the elements in the tensor with the embedding of the mask token
new_target[:, :, :] = self.diffusion_mask_emb
return new_target
# prepare the embedding of the target,
def _prepare_embedding(self, memory_list, target):
for _, feature in enumerate(self.prediction_order):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
return memory_tensor
def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
memory_list = [] # used for key and value in cross attention
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
if add_BOS:
if target is not None: # training
memory_list.append(BOS_emb)
else: # inference
memory_list.append(BOS_emb[-1:, :, :])
else:
pass
return memory_list
def _get_num_transfer_tokens(self, mask_index, steps):
'''
In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
the expected number of tokens transitioned at each step should be consistent.
This function is designed to precompute the number of tokens that need to be transitioned at each step.
'''
mask_num = mask_index.sum(dim=1,keepdim=True)
base = mask_num // steps
remainder = mask_num % steps
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
for i in range(mask_num.size(0)):
num_transfer_tokens[i, :remainder[i]] += 1
return num_transfer_tokens
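# Illustrative sketch, not part of this commit: the schedule produced by
# _get_num_transfer_tokens for a single row with 8 masked sub-tokens and
# denoising_steps = 6. base = 8 // 6 = 1 and remainder = 8 % 6 = 2, so the
# first two steps reveal one extra token:
#   num_transfer_tokens = [[2, 2, 1, 1, 1, 1]]
# i.e. roughly the same number of tokens is unmasked at every step, consistent
# with the linear noise schedule described in the docstring.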
def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False,step=None):
sampled_token_dict = {}
logits_dict = {}
candidate_token_embeddings = {}
candidate_token_probs = {}
b,t,d = hidden_vec.shape # B x T x d_model
# print("*"*8)
logits_list = []
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_list.append(logit)
for idx, feature in enumerate(self.prediction_order):
logit = logits_list[idx] # B x T x vocab_size
sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if step==0 and force_decode:
if feature == 'velocity':
sampled_token = torch.tensor([2]).to(logit.device)
prob = torch.tensor([1.0]).to(logit.device)
else:
prob = torch.tensor([0.0]).to(logit.device)
# print(feature, sampled_token, prob)
sampled_token_dict[feature] = sampled_token
logits_dict[feature] = logit
candidate_token_probs[feature] = prob
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens x vocab_size
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
# print("sampled_token_dict", sampled_token_dict)
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
if bos_hidden_vec is None: # start of generation
if target is None:
bos_hidden_vec = input_seq_pos
else:
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
else:
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = input_seq
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
# indicates which positions are still masked; True means the token has been masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
# add attribute control here
stored_logits_dict = {}
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
if condition_step is not None:
# print("shape of condition_step", condition_step.shape)
condition_step = condition_step.reshape((b*t, l))
for i in range(b*t):
for j in range(l):
token = condition_step[i][j]
if condition_step[i][j] != self.MASK_idx:
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising loop
# with torch.profiler.profile(
# activities=[
# torch.profiler.ProfilerActivity.CPU,
# torch.profiler.ProfilerActivity.CUDA],
# record_shapes=True,
# profile_memory=True,
# with_stack=True
# ) as prof:
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
force_decode=Force_decode,
step=step)
# print("step", step)
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
# print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)

View File

@ -22,6 +22,8 @@ class Augmentor:
self.chord_idx = self.feature_list.index('chord')
def _get_shift(self, segment):
if self.encoding_scheme == 'oct':
return 0
# the pitch vocab has ignore token in 0 index
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
pitch_mask = segment != 0

View File

@ -73,7 +73,7 @@ class VanillaTransformer_compiler():
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
eos_token = torch.LongTensor(self.eos_token)
else:
eos_token = torch.LongTensor(self.eos_token)
@ -148,7 +148,7 @@ class VanillaTransformer_compiler():
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
eos_token = torch.LongTensor(self.eos_token)
else:
eos_token = torch.LongTensor(self.eos_token)

View File

@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
print(f"Error encoding caption for tune {tune_name}: {e}")
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
return segment, tensor_mask, tune_name, encoded_caption
if self.data_type == 'train':
augmented_segment = self.augmentor(segment)
return augmented_segment, tensor_mask, tune_name, encoded_caption
else:
return segment, tensor_mask, tune_name, encoded_caption
def get_segments_with_tune_idx(self, tune_name, seg_order):
'''
@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
self.data_type = data_type
self.augmentor = augmentor
self.eos_token = vocab.eos_token
self.vocab = vocab
self.compile_function = VanillaTransformer_compiler(
data_list=self.data_list,
augmentor=self.augmentor,
@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
except Exception as e:
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
if self.data_type == 'train':
if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
segment = self.augmentor(segment)
# use input_ids replace tune_name
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption

View File

@ -1,13 +1,17 @@
import re
import os, sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from symusic import Score
from miditok import Octuple, TokenizerConfig
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
class MidiDecoder4Octuple(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
"""
Input:
t: torch.Tensor of shape (1, N, M)
Purpose:
Remove rows (sub-token tuples) whose entries contain any of the values 0, 1, 2 or 3
Returns:
torch.Tensor with the batch dimension kept, shape (1, N_filtered, M)
"""
if t.dim() != 3:
raise ValueError("input tensor must be 3-dimensional (batch, seq_len, feature)")
# build a mask: True means the row contains none of 0, 1, 2, 3
exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
# check whether each row contains any of these values
mask = ~((t[0][..., None] == exclude_vals).any(dim=(1, 2)))
# filter the rows while keeping the batch dimension
filtered_tensor = t[0][mask].unsqueeze(0)
return filtered_tensor
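# Illustrative sketch, not part of this commit: the filter above keeps only
# rows whose entries avoid the ids 0-3 (presumably the PAD/BOS/EOS/MASK
# specials of the miditok vocab). For example:
#   t        = torch.tensor([[[65, 10, 4], [0, 9, 9], [64, 2, 8]]])
#   filtered = tensor([[[65, 10, 4]]])   # rows containing 0 or 2 are dropped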
def __call__(self, generated_output, output_path=None):
config = TokenizerConfig(
use_time_signatures=True,
use_tempos=True,
use_velocities=True,
use_programs=True,
remove_duplicated_notes=True,
delete_equal_successive_tempo_changes=True,
)
config.additional_params["max_bar_embedding"] = 512
tokenizer = Octuple(config)
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# generated_output = generated_output[:, 1:generated_output.shape[1]-1, :] # remove sos token
generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
print(output_path)
try:
tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
tok_seq.dump_midi(output_path)
except Exception as e:
print(generated_output)
print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
tok_seq = None
return tok_seq

View File

@ -1,8 +1,8 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
- nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
- nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M
@ -15,7 +15,7 @@ defaults:
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
dataset: Melody # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@ -31,7 +31,7 @@ tau: 0.5
train_params:
device: cuda
batch_size: 5
batch_size: 10
grad_clip: 1.0
num_iter: 300000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
@ -44,7 +44,7 @@ train_params:
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.0003
initial_lr: 0.00001
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 #number of warmup steps

View File

@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -29,8 +29,12 @@ def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_pa
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
}
if encoding_scheme == 'remi':
oct_prediction_order = {
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
if encoding_scheme == 'oct':
prediction_order = oct_prediction_order[num_features]
elif encoding_scheme == 'remi':
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'cp':
if nn_params.get("partial_sequential_prediction", False):
@ -239,11 +243,11 @@ class DiffusionLoss4CompoundToken():
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
train_loss_list.append(training_loss)
if valid:
if key == 'type':
if key == 'type' or key == 'timesig':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'beat':
elif key == 'beat' or key == 'position' or key == 'bar':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'chord' or key == 'tempo' or key == 'instrument':
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])

View File

@ -74,7 +74,7 @@ class LanguageModelTrainer:
sampling_threshold: float, # Threshold for sampling decisions
sampling_temperature: float, # Temperature for controlling sampling randomness
config, # Configuration parameters (contains general, training, and inference settings)
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
# model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
):
# Save model, optimizer, and other configurations
@ -104,7 +104,6 @@ class LanguageModelTrainer:
checkpoint = torch.load(model_checkpoint, map_location='cpu')
# print state dict keys
print("Loading model checkpoint from", model_checkpoint)
print("Checkpoint keys:", checkpoint['model'].keys())
if isinstance(self.model, DDP):
self.model.module.load_state_dict(checkpoint['model'], strict=False)
else:
@ -902,9 +901,9 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
correct_guess_by_feature = defaultdict(int)
num_tokens_by_feature = defaultdict(int)
for idx, key in enumerate(self.vocab.feature_list):
if key == 'type':
if key == 'type' or key == 'timesig' :
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
elif key == 'chord' or key == 'tempo' or key == 'instrument':
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
elif key == 'beat':
# NB's beat vocab has Ignore and CONTI token

View File

@ -0,0 +1,135 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MIDI preprocessing script (parallel version)
Features:
1. Uses the miditok Octuple tokenizer.
2. Restricts MIDI file duration to between 8 and 2000 seconds.
3. Defaults to 120 BPM when tempo is missing and to 4/4 when the time signature is missing.
4. Saves vocab.json.
5. Walks all MIDI files under the directory and tokenizes them in parallel, saving each file separately as {filename}.npz.
"""
import os
import glob
import struct
import numpy as np
from multiprocessing import RLock
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from symusic import Score
from miditok import Octuple, TokenizerConfig
lock = RLock()
def convert_event_dicts(dict_list):
"""
Convert the list of per-event vocab dicts, in order, into a structured output.
Input: list[dict]
Each dict corresponds to one category, arranged in a fixed order:
0: Pitch/PitchDrum
1: Position
2: Bar
(+ Optional) Velocity
(+ Optional) Duration
(+ Optional) Program
(+ Optional) Tempo
(+ Optional) TimeSignature
Example output:
{
"pitch": {"0": 0, "1": "Pitch_60", ...},
"position": {...},
...
}
"""
keys_order = [
"pitch", "position", "bar",
"velocity", "duration", "program",
"tempo", "timesig"
]
result = {}
for i, d in enumerate(dict_list):
if i >= len(keys_order):
break # ignore anything beyond the defined categories
category = keys_order[i]
result[category] = {str(v): k for k, v in d.items()}
return result
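# Illustrative sketch, not part of this commit: convert_event_dicts on a tiny
# hypothetical input, showing how each per-category token->id dict is inverted
# into an id->token mapping keyed by the fixed category names:
#   dict_list = [{"Pitch_60": 5}, {"Position_0": 4}, {"Bar_0": 4}]
#   convert_event_dicts(dict_list)
#   -> {"pitch": {"5": "Pitch_60"}, "position": {"4": "Position_0"}, "bar": {"4": "Bar_0"}}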
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
try:
score = Score(midi_path, ttype="tick")
if(not (8 <= (duration := score.to('second').end()) <= 2000)):
with lock:
print(f" × 时长不符合要求:{midi_path} -> {duration}s")
return
# tokenize
tok_seq = tokenizer(score)
token_ids = tok_seq.ids
# add sos token at the beginning
vocab = tokenizer.vocab
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
token_ids.insert(0, sos_token)
token_ids.sort(key=lambda x: (x[2], x[1])) # pos in 1, bar in 2
# save each file as its own npz
filename = os.path.splitext(os.path.basename(midi_path))[0]
save_path = os.path.join(output_dir, f"{filename}.npz")
np.savez_compressed(save_path, np.array(token_ids))
except Exception as e:
with lock:
print(f" × 处理文件时出错:{midi_path} -> {e}")
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
# === 1. Initialize the tokenizer and save the vocab ===
print("Initializing Octuple tokenizer...")
config = TokenizerConfig(
use_time_signatures=True,
use_tempos=True,
use_velocities=True,
use_programs=True,
remove_duplicated_notes=True,
delete_equal_successive_tempo_changes=True,
)
config.additional_params["max_bar_embedding"] = 512
tokenizer = Octuple(config)
vocab = tokenizer.vocab
vocab_structured = convert_event_dicts(vocab)
with open( "vocab/oct_vocab.json", "w", encoding="utf-8") as f:
import json
json.dump(vocab_structured, f, ensure_ascii=False, indent=4)
# === 2. Create the output directory ===
os.makedirs(output_dir, exist_ok=True)
# === 3. Collect MIDI files ===
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
midi_paths = list(midi_paths)
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
# === 4. Process in parallel ===
results = []
with ProcessPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
for future in tqdm(as_completed(futures), total=len(futures)):
res = future.result()
if res:
results.append(res)
# === 5. Summarize results ===
print(f"\nDone: {len(results)} .npz files generated in {output_dir}/.")
if __name__ == "__main__":
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
dataset_name = midi_directory.split("/")[-1]
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
output_dir = tuneidx_prefix
preprocess_midi_directory(midi_directory, output_dir)

View File

@ -0,0 +1,14 @@
import numpy as np
# load the npz file
data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
# list the stored keys
print(data.files)
# output: ['arr_0'] (the array was saved without an explicit key)
# access the data
sequence = data["arr_0"]
print("token sequence length:", len(sequence))
print("first 20 tokens:", sequence[:20])

View File

@ -1,5 +1,6 @@
import pickle
from pathlib import Path
from re import L
from typing import Union
from multiprocessing import Pool, cpu_count
from collections import defaultdict
@ -58,7 +59,7 @@ class LangTokenVocab:
if in_vocab_file_path is not None:
with open(in_vocab_file_path, 'r') as f:
idx2event_temp = json.load(f)
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb' or self.encoding_scheme == 'oct':
for key in idx2event_temp.keys():
idx2event_temp[key] = {int(idx):tok for idx, tok in idx2event_temp[key].items()}
elif self.encoding_scheme == 'remi':
@ -71,13 +72,18 @@ class LangTokenVocab:
# Extracts features depending on the number of features chosen (4, 5, 7, 8).
def _get_features(self):
if self.encoding_scheme != 'oct':
feature_args = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
self.feature_list = feature_args[self.num_features]
else:
feature_args = {
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
self.feature_list = feature_args[self.num_features]
# Saves the current vocabulary to a specified JSON path.
def save_vocab(self, json_path):
with open(json_path, 'w') as f:
@ -93,13 +99,17 @@ class LangTokenVocab:
self.sos_token = [self.event2idx['SOS_None']]
self.eos_token = [[self.event2idx['EOS_None']]]
else:
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
else: # oct
self.sos_token = [[self.event2idx['pitch']['BOS_None']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['pitch']['EOS_None']] + [0] * (self.num_features - 1)]
# Generates vocabularies by either loading from a file or creating them based on the event data.
def _get_vocab(self, event_data, unique_vocabs=None):
# make new vocab from given event_data
if event_data is not None:
if event_data is not None and self.encoding_scheme != 'oct':
unique_char_list = list(set([f'{event["name"]}_{event["value"]}' for tune_path in event_data for event in pickle.load(open(tune_path, 'rb'))]))
unique_vocabs = sorted(unique_char_list)
unique_vocabs.remove('SOS_None')
@ -119,6 +129,7 @@ class LangTokenVocab:
# load premade vocab
else:
idx2event = unique_vocabs
print(idx2event)
event2idx = {tok : int(idx) for idx, tok in unique_vocabs.items()}
return idx2event, event2idx
@ -393,3 +404,46 @@ class MusicTokenVocabNB(MusicTokenVocabCP):
unique_vocabs.insert(4, 'SSN')
unique_vocabs.insert(5, 'SNN')
return unique_vocabs
class MusicTokenVocabOct(LangTokenVocab):
def __init__(
self,
in_vocab_file_path:Union[Path, None],
event_data: list,
encoding_scheme: str,
num_features: int
):
super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)
def _get_vocab(self, event_data, unique_vocabs=None):
if event_data is not None:
# Create vocab mappings (event2idx, idx2event) from the provided event data
print('start to get unique vocab')
event2idx = {}
idx2event = {}
unique_vocabs = defaultdict(set)
# Use multiprocessing to extract unique vocabularies for each event
with Pool(16) as p:
results = p.starmap(self._mp_get_unique_vocab, tqdm([(tune, self.feature_list) for tune in event_data]))
# Combine results from different processes
for result in results:
for key in self.feature_list:
unique_vocabs[key].update(result[key])
# Process each feature type
for key in self.feature_list:
unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
# Create event2idx and idx2event mappings for each feature
event2idx[key] = {tok: int(idx) for idx, tok in enumerate(unique_vocabs[key])}
idx2event[key] = {int(idx): tok for idx, tok in enumerate(unique_vocabs[key])}
return idx2event, event2idx
else:
# If no event data, simply map unique vocab to indexes
event2idx = {}
for key in self.feature_list:
event2idx[key] = {tok: int(idx) for idx, tok in unique_vocabs[key].items()}
return unique_vocabs, event2idx
def get_vocab_size(self):
# Return the size of the vocabulary for each feature
return {key: len(self.idx2event[key]) for key in self.feature_list}

View File

@ -33,7 +33,9 @@ def get_argument_parser():
parser.add_argument(
"-attr_list",
type=str,
default="beat,duration",
# default="beat,duration,,instrument,tempo",
default="pitch",
# default='bar,position,velocity,duration,program,tempo,timesig',
help="attribute list for attribute-controlled generation",
)
parser.add_argument(
@ -69,7 +71,7 @@ def get_argument_parser():
parser.add_argument(
"-num_target_measure",
type=int,
default=4,
default=128,
help="number of target measures for conditioned generation",
)
parser.add_argument(
@ -86,13 +88,13 @@ def get_argument_parser():
parser.add_argument(
"-num_processes",
type=int,
default=1,
default=4,
help="number of processes to use",
)
parser.add_argument(
"-gpu_ids",
type=str,
default="1,2,3,5",
default="0,1,2,3,5",
help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
)
parser.add_argument(
@ -203,6 +205,7 @@ def attr_conditioned_worker(process_idx, gpu_id, args):
# end_idx = min(chunk_size, len(selected_data))
# data_slice = selected_data[start_idx:end_idx]
data_slice = selected_data
print("data_slice length:", len(data_slice))
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \

4 binary image files changed (not shown)

View File

@ -129,6 +129,6 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",
if __name__ == "__main__":
dir_a = "wandb/run-20251015_154556-f0pj3ys3/cond_4m_top_p_t0.99_temp1.25/process_2_batch_23"
dir_a = "wandb/run-20251027_161354-f9j1mwp2/uncond_min_p_t0.05_temp1.25_epochch8"
dir_b = "dataset/Melody"
batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)

View File

@ -8,9 +8,6 @@ import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from accelerate import Accelerator
from accelerate.utils import set_seed
import wandb
import hydra
from hydra.core.hydra_config import HydraConfig
@ -20,6 +17,8 @@ from omegaconf import DictConfig, OmegaConf
from accelerate import Accelerator
from accelerate.utils import set_seed
from miditok import Octuple, TokenizerConfig
from Amadeus.symbolic_encoding import data_utils, decoding_utils
from Amadeus.symbolic_encoding.data_utils import get_emb_total_size
from Amadeus import model_zoo, trainer_accelerate as trainer
@ -99,7 +98,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
out_vocab_path = Path(save_dir) / f'vocab_{dataset_name}_{encoding_scheme}{num_features}.json'
# get vocab
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'}
selected_vocab_name = vocab_name[encoding_scheme]
vocab = getattr(vocab_utils, selected_vocab_name)(
@ -159,7 +158,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
focal_gamma = config.train_params.focal_gamma
if encoding_scheme == 'remi':
loss_fn = NLLLoss4REMI(focal_alpha=focal_alpha, focal_gamma=focal_gamma)
elif encoding_scheme in ['cp', 'nb']:
elif encoding_scheme in ['cp', 'nb', 'oct']:
if config.use_diff is False:
loss_fn = NLLLoss4CompoundToken(feature_list=symbolic_dataset.vocab.feature_list, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
else:
@ -181,11 +180,11 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
in_beat_resolution = in_beat_resolution_dict[dataset_name]
except KeyError:
in_beat_resolution = 4
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
midi_decoder = getattr(decoding_utils, midi_decoder_dict[encoding_scheme])(vocab=symbolic_dataset.vocab, in_beat_resolution=in_beat_resolution, dataset_name=dataset_name)
# Select trainer class based on encoding scheme
trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken'}
trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken', 'oct':'LanguageModelTrainer4CompoundToken'}
trainer_option = trainer_option_dict[encoding_scheme]
sampling_method = None
sampling_threshold = 0.99

vocab.json (new file, 442 lines)
View File

@ -0,0 +1,442 @@
{
"config": {
"pitch_range": [
21,
109
],
"beat_res": {
"0_4": 8,
"4_12": 4
},
"num_velocities": 32,
"remove_duplicated_notes": false,
"encode_ids_split": "bar",
"special_tokens": [
"PAD_None",
"BOS_None",
"EOS_None",
"MASK_None"
],
"use_velocities": true,
"use_note_duration_programs": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
66,
67,
68,
69,
70,
71,
72,
73,
74,
75,
76,
77,
78,
79,
80,
81,
82,
83,
84,
85,
86,
87,
88,
89,
90,
91,
92,
93,
94,
95,
96,
97,
98,
99,
100,
101,
102,
103,
104,
105,
106,
107,
108,
109,
110,
111,
112,
113,
114,
115,
116,
117,
118,
119,
120,
121,
122,
123,
124,
125,
126,
127,
-1
],
"use_chords": false,
"use_rests": false,
"use_tempos": true,
"use_time_signatures": true,
"use_sustain_pedals": false,
"use_pitch_bends": false,
"use_programs": false,
"use_pitch_intervals": false,
"use_pitchdrum_tokens": true,
"default_note_duration": 0.5,
"programs": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
61,
62,
63,
64,
65,
66,
67,
68,
69,
70,
71,
72,
73,
74,
75,
76,
77,
78,
79,
80,
81,
82,
83,
84,
85,
86,
87,
88,
89,
90,
91,
92,
93,
94,
95,
96,
97,
98,
99,
100,
101,
102,
103,
104,
105,
106,
107,
108,
109,
110,
111,
112,
113,
114,
115,
116,
117,
118,
119,
120,
121,
122,
123,
124,
125,
126,
127,
-1
],
"one_token_stream_for_programs": false,
"program_changes": false,
"beat_res_rest": {
"0_1": 8,
"1_2": 4,
"2_12": 2
},
"chord_maps": {
"min": [
0,
3,
7
],
"maj": [
0,
4,
7
],
"dim": [
0,
3,
6
],
"aug": [
0,
4,
8
],
"sus2": [
0,
2,
7
],
"sus4": [
0,
5,
7
],
"7dom": [
0,
4,
7,
10
],
"7min": [
0,
3,
7,
10
],
"7maj": [
0,
4,
7,
11
],
"7halfdim": [
0,
3,
6,
10
],
"7dim": [
0,
3,
6,
9
],
"7aug": [
0,
4,
8,
11
],
"9maj": [
0,
4,
7,
10,
14
],
"9min": [
0,
4,
7,
10,
13
]
},
"chord_tokens_with_root_note": false,
"chord_unknown": null,
"num_tempos": 32,
"tempo_range": [
40,
250
],
"log_tempos": false,
"delete_equal_successive_tempo_changes": true,
"time_signature_range": {
"8": [
3,
12,
6
],
"4": [
5,
6,
3,
2,
1,
4
]
},
"delete_equal_successive_time_sig_changes": false,
"sustain_pedal_duration": false,
"pitch_bend_range": [
-8192,
8191,
32
],
"max_pitch_interval": 16,
"pitch_intervals_max_time_dist": 1,
"drums_pitch_range": [
27,
88
],
"ac_polyphony_track": false,
"ac_polyphony_bar": false,
"ac_polyphony_min": 1,
"ac_polyphony_max": 6,
"ac_pitch_class_bar": false,
"ac_note_density_track": false,
"ac_note_density_track_min": 0,
"ac_note_density_track_max": 18,
"ac_note_density_bar": false,
"ac_note_density_bar_max": 18,
"ac_note_duration_bar": false,
"ac_note_duration_track": false,
"ac_repetition_track": false,
"ac_repetition_track_num_bins": 10,
"ac_repetition_track_num_consec_bars": 4,
"additional_params": {
"max_bar_embedding": 60
}
},
"tokenization": "Octuple",
"miditok_version": "3.0.6.post1",
"symusic_version": "0.5.8",
"hf_tokenizers_version": "0.21.2"
}