1029 add octuple

Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

@@ -103,7 +103,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
super().__init__()
self.net = net
# self.prediction_order = net.prediction_order
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
self.attribute2idx_after = {'pitch': 0,
'duration': 1,
'velocity': 2,
@@ -113,6 +113,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
'tempo': 6,
'instrument': 7}
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
+ # if using position attribute, change accordingly
+ if 'position' in self.net.vocab.feature_list:
+     self.attribute2idx_after = {'pitch': 0,
+         'position': 1,
+         'bar': 2,
+         'velocity': 3,
+         'duration': 4,
+         'program': 5,
+         'tempo': 6,
+         'timesig': 7}
+     self.attribute2idx = {'pitch':0, 'position':1, 'bar':2, 'velocity':3, 'duration':4, 'program':5, 'tempo':6, 'timesig':7}
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
return self.net(input_seq, target, context=context)
@@ -161,10 +172,12 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
# measure_bool = (condition[:,1] == 1) # measure tokens
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
- elif self.net.vocab.encoding_scheme == 'nb':
+ elif self.net.vocab.encoding_scheme == 'nb' or self.net.vocab.encoding_scheme == 'oct':
measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
- conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
+ try:
+     conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
+ except:
+     conditional_input_len = condition.shape[0]
if conditional_input_len == 0:
conditional_input_len = 50
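
A minimal sketch of the fallback introduced in this hunk, using made-up tensors and a made-up measure count: if the prompt holds fewer measure starts than num_target_measures, indexing the boundary positions raises an IndexError and the full prompt length is used instead.

```python
# Hypothetical illustration of the try/except fallback above; the condition
# tensor and num_target_measures are invented for the example.
import torch

condition = torch.tensor([[2, 0], [6, 0], [3, 0]])  # 3 token rows, column 0 holds the type id
num_target_measures = 4                             # asks for more measures than the prompt has

measure_bool = (condition[:, 0] == 2) | (condition[:, 0] >= 5)  # empty measure or new measure start
try:
    conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
except IndexError:                                  # the diff uses a bare except with the same effect
    conditional_input_len = condition.shape[0]      # fall back to the whole prompt

print(conditional_input_len)  # -> 3
```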
@@ -262,7 +275,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
# print(self.attribute2idx)
for attr, idx in self.attribute2idx.items():
if attr not in attr_list:
- condition_filtered[:, :, idx] = 126336
+ condition_filtered[:, 1:, idx] = 126336
# rearange condition_filtered to match prediction order
cache = LayerIntermediates()
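
A short sketch of the masking tweak in this hunk, with invented shapes and an invented attribute list; the mask id 126336 and the attribute2idx mapping are taken from the diff. Attributes that are not conditioned on are overwritten with the mask id, and after this change the first timestep is left untouched.

```python
# Hypothetical example: only the mask id and the column mapping come from the diff.
import torch

MASK_ID = 126336
condition_filtered = torch.zeros(1, 4, 8, dtype=torch.long)  # (batch, steps, attributes), invented shape

attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3,
                 'instrument': 4, 'pitch': 5, 'duration': 6, 'velocity': 7}
attr_list = ['pitch', 'duration']                            # attributes kept as conditions (invented)

for attr, idx in attribute2idx.items():
    if attr not in attr_list:
        condition_filtered[:, 1:, idx] = MASK_ID             # was [:, :, idx] before this commit

print(condition_filtered[0, 0])  # first step keeps its original (zero) values in every column
print(condition_filtered[0, 1])  # later steps are masked except the pitch and duration columns
```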
@@ -286,8 +299,9 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
for attr, idx in self.attribute2idx.items():
new_idx = self.attribute2idx_after[attr]
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
# print("condition_step shape:", condition_step.shape)
# print("condition_step shape:", condition_step)
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
# print("sampled_token shape:", sampled_token)
else:
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()
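
A self-contained sketch of the column rearrangement performed in the loop above, using a hypothetical three-attribute vocabulary instead of the eight-attribute attribute2idx / attribute2idx_after mappings defined earlier: each column is copied from its storage position to the position the model predicts it in.

```python
# Hypothetical three-attribute version of the rearrangement loop above.
import torch

attribute2idx = {'type': 0, 'pitch': 1, 'duration': 2}        # storage order (invented)
attribute2idx_after = {'pitch': 0, 'duration': 1, 'type': 2}  # prediction order (invented)

condition_step = torch.tensor([[[10, 60, 8]]])                # (batch=1, steps=1, attrs=3)
condition_step_rearranged = torch.empty_like(condition_step)

for attr, idx in attribute2idx.items():
    new_idx = attribute2idx_after[attr]
    condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]

print(condition_step_rearranged)  # tensor([[[60,  8, 10]]])
```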