1029 add octuple
This commit is contained in:
@ -103,7 +103,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
# self.prediction_order = net.prediction_order
|
||||
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
|
||||
|
||||
self.attribute2idx_after = {'pitch': 0,
|
||||
'duration': 1,
|
||||
'velocity': 2,
|
||||
@ -113,6 +113,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
'tempo': 6,
|
||||
'instrument': 7}
|
||||
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
|
||||
# if using position attribute, change accordingly
|
||||
if 'position' in self.net.vocab.feature_list:
|
||||
self.attribute2idx_after = {'pitch': 0,
|
||||
'position': 1,
|
||||
'bar': 2,
|
||||
'velocity': 3,
|
||||
'duration': 4,
|
||||
'program': 5,
|
||||
'tempo': 6,
|
||||
'timesig': 7}
|
||||
self.attribute2idx = {'pitch':0, 'position':1, 'bar':2, 'velocity':3, 'duration':4, 'program':5, 'tempo':6, 'timesig':7}
|
||||
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
|
||||
return self.net(input_seq, target, context=context)
|
||||
|
||||
@ -161,10 +172,12 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
# measure_bool = (condition[:,1] == 1) # measure tokens
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
elif self.net.vocab.encoding_scheme == 'nb':
|
||||
elif self.net.vocab.encoding_scheme == 'nb' or self.net.vocab.encoding_scheme == 'oct':
|
||||
measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
|
||||
try:
|
||||
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
|
||||
except:
|
||||
conditional_input_len = condition.shape[0]
|
||||
if conditional_input_len == 0:
|
||||
conditional_input_len = 50
|
||||
|
||||
@ -262,7 +275,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
# print(self.attribute2idx)
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
if attr not in attr_list:
|
||||
condition_filtered[:, :, idx] = 126336
|
||||
condition_filtered[:, 1:, idx] = 126336
|
||||
# rearange condition_filtered to match prediction order
|
||||
|
||||
cache = LayerIntermediates()
|
||||
@ -286,8 +299,9 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
new_idx = self.attribute2idx_after[attr]
|
||||
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
|
||||
# print("condition_step shape:", condition_step.shape)
|
||||
# print("condition_step shape:", condition_step)
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
|
||||
# print("sampled_token shape:", sampled_token)
|
||||
else:
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
|
||||
time_end = time.time()
|
||||
|
||||
Reference in New Issue
Block a user