1021 add flexible attr control

FelixChan
2025-10-21 15:27:03 +08:00
parent d6b68ef90b
commit b493ede479
15 changed files with 400 additions and 394 deletions


@@ -102,7 +102,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
'''
super().__init__()
self.net = net
# self.prediction_order = net.prediction_order
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
self.attribute2idx_after = {'pitch': 0,
'duration': 1,
'velocity': 2,
'type': 3,
'beat': 4,
'chord': 5,
'tempo': 6,
'instrument': 7}
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
return self.net(input_seq, target, context=context)
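The two lookup tables added above encode a fixed permutation between the raw attribute layout (type, beat, chord, tempo, instrument, pitch, duration, velocity) and the order in which the sub-decoder predicts attributes (pitch first, instrument last). A minimal standalone sketch of that mapping, reusing the dictionaries from this diff (the helper function itself is hypothetical):

attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3,
                 'instrument': 4, 'pitch': 5, 'duration': 6, 'velocity': 7}
attribute2idx_after = {'pitch': 0, 'duration': 1, 'velocity': 2, 'type': 3,
                       'beat': 4, 'chord': 5, 'tempo': 6, 'instrument': 7}

def to_prediction_order(token):
    """Reorder one 8-dim attribute token from raw layout to prediction order."""
    out = [0] * len(token)
    for attr, idx in attribute2idx.items():
        out[attribute2idx_after[attr]] = token[idx]
    return out

print(to_prediction_order(list(range(8))))  # [5, 6, 7, 0, 1, 2, 3, 4]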
@@ -164,7 +174,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
return total_out
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None, context=None, condition_step=None):
'''
Runs one step of autoregressive generation by taking the input sequence, embedding it,
passing it through the main decoder, and generating logits and a sampled token.
@@ -192,7 +202,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
# Generate the next token
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
return logits, sampled_token, intermidiates, hidden_vec
def _update_total_out(self, total_out, sampled_token):
@@ -225,7 +235,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
return total_out, sampled_token
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
'''
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
until the desired maximum sequence length is reached or the end token is encountered.
@@ -243,15 +253,19 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
- total_out: The generated sequence of tokens as a tensor.
'''
# Prepare the starting sequence for inference
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
# If a condition is provided, run one initial step
if condition is not None:
_, _, cache = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
if attr_list is None:
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
else:
cache = LayerIntermediates()
# Continue generating tokens until the maximum sequence length is reached
total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
# For attribute-controlled generation, keep only the attributes named in attr_list; mask every other attribute with the mask id 126336
condition_filtered = condition.clone().unsqueeze(0)
for attr, idx in self.attribute2idx.items():
if attr not in attr_list:
condition_filtered[:, :, idx] = 126336
# condition_filtered is rearranged step by step below to match the sub-decoder's prediction order
cache = LayerIntermediates()
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
bos_hidden_vec = None
hidden_vec_list = []
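For reference, a self-contained sketch of the masking step above, assuming condition is a (seq_len, 8) LongTensor in raw attribute order and 126336 is the mask id used throughout this commit (the function name is hypothetical):

import torch

MASK_ID = 126336  # mask id for attributes left uncontrolled, per this commit

def filter_condition(condition: torch.Tensor, attr_list, attribute2idx) -> torch.Tensor:
    """Keep only the attributes named in attr_list; mask all other columns."""
    filtered = condition.clone().unsqueeze(0)  # (1, seq_len, 8), matching the diff
    for attr, idx in attribute2idx.items():
        if attr not in attr_list:
            filtered[:, :, idx] = MASK_ID
    return filtered

# e.g. condition only on chord and tempo:
# filtered = filter_condition(cond, ['chord', 'tempo'], attribute2idx)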
@@ -261,7 +275,21 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
input_tensor = total_out[:, -self.net.input_length:]
# Generate the next token and update the cache
time_start = time.time()
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
# If attr_list is given, take the condition token for the current time step from condition_filtered
if attr_list is not None:
condition_filtered = condition_filtered.to(self.net.device)
condition_step = condition_filtered[:, total_out.shape[1]-1:total_out.shape[1], :]
# rearrange from raw attribute order into prediction order (old index -> new index): 5->0, 6->1, 7->2, 0->3, 1->4, 2->5, 3->6, 4->7
condition_step_rearranged = torch.zeros_like(condition_step)
for attr, idx in self.attribute2idx.items():
new_idx = self.attribute2idx_after[attr]
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
else:
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()
token_time_list.append(time_end - time_start)
if bos_hidden_vec is None:
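Putting the per-step logic together: a standalone sketch of the slice-and-permute performed inside the loop, under the assumption that condition_filtered has shape (1, seq_len, 8) and cur_len is the current length of total_out (the helper itself is hypothetical):

import torch

def condition_for_step(condition_filtered, cur_len, attribute2idx, attribute2idx_after):
    """Slice out the condition token for the position being generated and
    permute it into the sub-decoder's prediction order, as the loop above does."""
    step = condition_filtered[:, cur_len - 1:cur_len, :]  # (1, 1, 8)
    rearranged = torch.zeros_like(step)
    for attr, idx in attribute2idx.items():
        rearranged[:, :, attribute2idx_after[attr]] = step[:, :, idx]
    return rearranged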
@@ -416,11 +444,11 @@ class AmadeusModel(nn.Module):
return self.decoder(input_seq, target, context=context)
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
if batch_size == 1:
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
else:
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)
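A hypothetical call site for the new attr_list argument; the model object, condition tensor, and sampling settings below are illustrative, not taken from this repository:

# condition: (seq_len, 8) LongTensor of attribute tokens from a reference piece
generated = model.generate(
    manual_seed=42,
    max_seq_len=1024,
    condition=condition,
    sampling_method=None,   # whichever sampling method the checkpoint expects
    threshold=None,
    temperature=1.0,
    batch_size=1,
    context=None,
    attr_list=['chord', 'tempo'],  # control only chord and tempo; other attributes are masked with 126336
)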
class AmadeusModel4Encodec(AmadeusModel):
def __init__(