1021 add flexible attr control
This commit is contained in:
@ -102,7 +102,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
'''
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
# self.prediction_order = net.prediction_order
|
||||
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
|
||||
self.attribute2idx_after = {'pitch': 0,
|
||||
'duration': 1,
|
||||
'velocity': 2,
|
||||
'type': 3,
|
||||
'beat': 4,
|
||||
'chord': 5,
|
||||
'tempo': 6,
|
||||
'instrument': 7}
|
||||
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
|
||||
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
|
||||
return self.net(input_seq, target, context=context)
|
||||
|
||||
@ -164,7 +174,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
|
||||
return total_out
|
||||
|
||||
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
|
||||
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None,condition_step=None):
|
||||
'''
|
||||
Runs one step of autoregressive generation by taking the input sequence, embedding it,
|
||||
passing it through the main decoder, and generating logits and a sampled token.
|
||||
@ -192,7 +202,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
|
||||
|
||||
# Generate the next token
|
||||
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
|
||||
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
|
||||
return logits, sampled_token, intermidiates, hidden_vec
|
||||
|
||||
def _update_total_out(self, total_out, sampled_token):
|
||||
@ -225,7 +235,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
return total_out, sampled_token
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
|
||||
'''
|
||||
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
|
||||
until the desired maximum sequence length is reached or the end token is encountered.
|
||||
@ -243,15 +253,19 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
- total_out: The generated sequence of tokens as a tensor.
|
||||
'''
|
||||
# Prepare the starting sequence for inference
|
||||
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
|
||||
|
||||
# If a condition is provided, run one initial step
|
||||
if condition is not None:
|
||||
_, _, cache = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
|
||||
if attr_list is None:
|
||||
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
|
||||
else:
|
||||
cache = LayerIntermediates()
|
||||
|
||||
# Continue generating tokens until the maximum sequence length is reached
|
||||
total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
|
||||
# for attribute-controlled generation, only keep the specified attributes in condition, others set to 126336
|
||||
condition_filtered = condition.clone().unsqueeze(0)
|
||||
# print(self.attribute2idx)
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
if attr not in attr_list:
|
||||
condition_filtered[:, :, idx] = 126336
|
||||
# rearrange condition_filtered to match prediction order
|
||||
|
||||
cache = LayerIntermediates()
|
||||
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
|
||||
bos_hidden_vec = None
|
||||
hidden_vec_list = []
|
||||
@ -261,7 +275,21 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
input_tensor = total_out[:, -self.net.input_length:]
|
||||
# Generate the next token and update the cache
|
||||
time_start = time.time()
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
|
||||
# if attr_list is not None, get one token in condition_filtered each time step
|
||||
if attr_list is not None:
|
||||
condition_filtered = condition_filtered.to(self.net.device)
|
||||
# print(condition_filtered[:,:20,:])
|
||||
# print(condition_filtered.shape)
|
||||
condition_step = condition_filtered[:, total_out.shape[1]-1:total_out.shape[1], :]
|
||||
# rearrange from attribute2idx order to attribute2idx_after order (source index -> target index): 0 -> 3, 1 -> 4, 2 -> 5, 3 -> 6, 4 -> 7, 5 -> 0, 6 -> 1, 7 -> 2
|
||||
condition_step_rearranged = torch.zeros_like(condition_step)
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
new_idx = self.attribute2idx_after[attr]
|
||||
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
|
||||
# print("condition_step shape:", condition_step.shape)
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
|
||||
else:
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
|
||||
time_end = time.time()
|
||||
token_time_list.append(time_end - time_start)
|
||||
if bos_hidden_vec is None:
|
||||
@ -416,11 +444,11 @@ class AmadeusModel(nn.Module):
|
||||
return self.decoder(input_seq, target, context=context)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None,attr_list=None):
|
||||
if batch_size == 1:
|
||||
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
|
||||
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
|
||||
else:
|
||||
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
|
||||
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)
|
||||
|
||||
class AmadeusModel4Encodec(AmadeusModel):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user