1013 update

This commit is contained in:
FelixChan
2025-10-13 17:56:36 +08:00
parent d077e3210e
commit d6b68ef90b
17 changed files with 815 additions and 70 deletions

View File

@ -1358,6 +1358,7 @@ class Attention(Module):
dim_latent_kv = None,
latent_rope_subheads = None,
onnxable = False,
use_gated_attention = False, # https://arxiv.org/abs/2505.06708
attend_sdp_kwargs: dict = dict(
enable_flash = True,
enable_math = True,
@ -1387,6 +1388,7 @@ class Attention(Module):
k_dim = dim_head * kv_heads
v_dim = value_dim_head * kv_heads
out_dim = value_dim_head * heads
gated_dim = out_dim
# determine input dimensions to qkv based on whether intermediate latent q and kv are being used
# for eventually supporting multi-latent attention (MLA)
@ -1447,7 +1449,8 @@ class Attention(Module):
self.to_v_gate = None
if gate_values:
self.to_v_gate = nn.Linear(dim, out_dim)
# self.to_v_gate = nn.Linear(dim, out_dim)
self.to_v_gate = nn.Linear(dim_kv_input, gated_dim)
self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 10)

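The gating change above can be read as output gating in the style of the paper linked in the diff (arXiv:2505.06708): a linear layer on the attention-block input produces a sigmoid gate that scales the attention output elementwise. A minimal sketch of that mechanism, assuming this is roughly what the to_v_gate path does; the module name and wiring here are illustrative, not the repository's exact code:

import torch
import torch.nn as nn

class OutputGate(nn.Module):
    # hypothetical illustration of the to_v_gate path shown in the diff
    def __init__(self, dim_in, out_dim):
        super().__init__()
        self.to_gate = nn.Linear(dim_in, out_dim)
        # zero weight + large positive bias => sigmoid(10) ~ 1, so the gate
        # starts as a near-identity and is shaped during training
        nn.init.constant_(self.to_gate.weight, 0)
        nn.init.constant_(self.to_gate.bias, 10)

    def forward(self, x, attn_out):
        # x: B x T x dim_in (attention input), attn_out: B x T x out_dim
        return attn_out * torch.sigmoid(self.to_gate(x))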
View File

@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
def top_p_sampling(logits, thres=0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@ -84,7 +84,7 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
# temporarily apply the sampling method to logits
logits = logits / temperature
# logits = add_gumbel_noise(logits, temperature)
if sampling_method == "top_p":
modified_logits = top_p_sampling(logits, thres=threshold)
elif sampling_method == "typical":
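The hunk above shows only the first lines of top_p_sampling. For reference, a self-contained sketch of standard nucleus (top-p) filtering along the same lines; the repository's version may use a different threshold convention:

import torch
import torch.nn.functional as F

def top_p_filter(logits, thres=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # mask tokens once the cumulative probability exceeds thres,
    # shifted right so the token that crosses the threshold is kept
    remove = cum_probs > thres
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
    # scatter the filtered logits back to the original vocabulary order
    return torch.empty_like(logits).scatter_(-1, sorted_indices, sorted_logits)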

View File

@ -530,6 +530,62 @@ class Melody(SymbolicMusicDataset):
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class msmidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality.
These include tunes that are not quantized properly, mostly expressive performance data.
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
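The split above groups files by stripping a trailing version suffix before shuffling, so every rendition of a song lands in the same split. A small illustration of the regex grouping (the file stems below are made up):

import re

stems = ["SongA.1", "SongA.2", "SongB"]            # hypothetical file stems
bases = [re.sub(r"\.\d+$", "", s) for s in stems]  # -> ['SongA', 'SongA', 'SongB']
groups = {}
for base, stem in zip(bases, stems):
    groups.setdefault(base, []).append(stem)
# train/valid/test are then drawn from the unique base names ('SongA', 'SongB'),
# and each split receives all versions of its songs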
class IrishMan(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
@ -648,62 +704,62 @@ class ariamidi(SymbolicMusicDataset):
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class gigamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
# class gigamidi(SymbolicMusicDataset):
# def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
# for_evaluation: bool = False):
# super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
# for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality.
These include tunes that are not quantized properly, mostly expressive performance data.
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
# def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
# '''
# Irregular tunes are removed from the dataset for better generation quality.
# These include tunes that are not quantized properly, mostly expressive performance data.
# '''
# print("preprocessed tune_in_idx data is being loaded")
# tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
# if self.debug:
# tune_in_idx_list = tune_in_idx_list[:5000]
# tune_in_idx_dict = OrderedDict()
# len_tunes = OrderedDict()
# file_name_list = []
# with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
# irregular_tunes = json.load(f)
# for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
# if tune_in_idx_file.stem in irregular_tunes:
# continue
# tune_in_idx = np.load(tune_in_idx_file)['arr_0']
# tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
# len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
# file_name_list.append(tune_in_idx_file.stem)
# print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
# return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
# def _get_split_list_from_tune_in_idx(self, ratio, seed):
# '''
# As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
# '''
# shuffled_tune_names = list(self.tune_in_idx.keys())
# song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
# song_dict = {}
# for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
# if song not in song_dict:
# song_dict[song] = []
# song_dict[song].append(orig_song)
# unique_song_names = list(song_dict.keys())
# random.seed(seed)
# random.shuffle(unique_song_names)
# num_train = int(len(unique_song_names)*ratio)
# num_valid = int(len(unique_song_names)*(1-ratio)/2)
# train_names = []
# valid_names = []
# test_names = []
# for song_name in unique_song_names[:num_train]:
# train_names.extend(song_dict[song_name])
# for song_name in unique_song_names[num_train:num_train+num_valid]:
# valid_names.extend(song_dict[song_name])
# for song_name in unique_song_names[num_train+num_valid:]:
# test_names.extend(song_dict[song_name])
# return train_names, valid_names, test_names, shuffled_tune_names
class ariamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
@ -788,6 +844,9 @@ class gigamidi(SymbolicMusicDataset):
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
if "drums-only" in tune_in_idx_file.stem:
print(f"skipping {tune_in_idx_file.stem} as it is a drums-only file")
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)

View File

@ -11,7 +11,7 @@ License: MIT, see the LICENSE file
__all__ = ['FluidSynth']
DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SOUND_FONT = 'Alex_GM.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"

View File

@ -2,7 +2,8 @@ defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
- nn_params: nb8_embSum_diff_t2m_150M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
- nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M
@ -14,7 +15,7 @@ defaults:
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@ -30,20 +31,20 @@ tau: 0.5
train_params:
device: cuda
batch_size: 3
batch_size: 5
grad_clip: 1.0
num_iter: 300000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
iterations_per_training_cycle: 10 # number of iterations for logging training loss
iterations_per_validation_cycle: 5000 # number of iterations for validation process
iterations_per_validation_cycle: 3000 # number of iterations for validation process
input_length: 3072 # input sequence length
# you can use focal loss; if it's not used, set focal_gamma to 0
focal_alpha: 1
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.00005
initial_lr: 0.0004
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 #number of warmup steps
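With scheduler: cosinelr, initial_lr 0.0004, warmup_steps 2000 and decay_step_rate 0.8, a plausible reading of the schedule is linear warmup followed by cosine decay that bottoms out at 80% of the total iterations. The actual logic lives in train_utils.py and may differ; a minimal sketch under that assumption:

import math

def lr_at(step, initial_lr=4e-4, warmup_steps=2000,
          total_iters=300_000, decay_step_rate=0.8, min_lr=0.0):
    decay_end = int(total_iters * decay_step_rate)       # lowest point of the cosine
    if step < warmup_steps:
        return initial_lr * step / max(1, warmup_steps)  # linear warmup
    progress = min(1.0, (step - warmup_steps) / max(1, decay_end - warmup_steps))
    return min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + math.cos(math.pi * progress))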

View File

@ -5,13 +5,13 @@ model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 20
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewFinetunningDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1024
num_layer: 32
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1024
num_layer: 32
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -729,9 +729,8 @@ class XtransformerNewPretrainingDecoder(nn.Module):
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
# shift_tokens = 1,
# attn_qk_norm = True,
# attn_qk_norm_dim_scale = True
attn_gate_values = True,
attn_qk_norm = True,
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
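attn_qk_norm = True enables query-key normalisation: queries and keys are normalised along the head dimension before the dot product, and a learned scale replaces the usual 1/sqrt(d) factor. A rough sketch of the idea; x-transformers' exact parameterisation may differ:

import torch
import torch.nn.functional as F

def qk_norm_scores(q, k, scale=10.0):
    # q, k: B x heads x T x d_head
    q = F.normalize(q, p=2, dim=-1)
    k = F.normalize(k, p=2, dim=-1)
    # cosine-similarity attention logits, scaled by a (typically learned) factor
    return torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale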
@ -758,7 +757,7 @@ class XtransformerNewPretrainingDecoder(nn.Module):
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
if ('to_q' in name or 'to_k' in name or 'to_v' in name) and 'to_v_gate' not in name: # exclude to_v_gate so its gate initialization is preserved
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None, context_embedding=None):
@ -906,6 +905,102 @@ class XtransformerFinetuningDecoder(nn.Module):
else:
context = context_embedding
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# cut to only return the seq part
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec
class XtransformerNewFinetunningDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
if dim != 768:
self.text_project = nn.Linear(768, dim) # assuming T5 base hidden size is 768
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
attn_gate_values = True,
attn_qk_norm = True,
) # add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if ('to_q' in name or 'to_k' in name or 'to_v' in name) and 'to_v_gate' not in name: # exclude to_v_gate so its gate initialization is preserved
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
if hasattr(self, 'text_project'):
context = self.text_project(context)
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference