import functools
import json  # noqa: F401  NOTE(review): unused in this file
import os
import sys
import time
from collections import defaultdict  # noqa: F401  NOTE(review): unused in this file
from email.mime import audio  # noqa: F401  NOTE(review): unused; looks like an accidental IDE auto-import
from pathlib import Path

import torch
from omegaconf import OmegaConf, DictConfig
from transformers import T5Tokenizer, T5EncoderModel
import gradio as gr

# Make the repository root importable so the local `Amadeus` / `data_representation` packages resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from Amadeus.train_utils import adjust_prediction_order  # noqa: E402
from Amadeus.evaluation_utils import (  # noqa: E402
    wandb_style_config_to_omega_config,
)
from Amadeus.symbolic_encoding import decoding_utils  # noqa: E402
from data_representation import vocab_utils  # noqa: E402
from Amadeus import model_zoo  # noqa: E402
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor  # noqa: E402


# === Utility functions (kept from the original code) ===

def get_best_ckpt_path_and_config(dir):
    """Locate the checkpoint, config and vocab files of an experiment directory.

    Despite the name, this returns the *latest* checkpoint: a ``*last.pt`` file if one
    exists, otherwise the ``*.pt`` file whose stem's first ``_``-separated field (with any
    ``iter`` prefix removed) is the largest integer.

    Args:
        dir: experiment directory (a ``Path``) containing ``files/checkpoints`` and
            ``files/config.yaml``. (The name shadows the builtin; kept for compatibility.)

    Returns:
        ``(ckpt_path, config_path, vocab_path)``.

    Raises:
        ValueError: if ``dir`` is None.
        FileNotFoundError: if no ``vocab*`` file or no ``*.pt`` checkpoint is found.
    """
    if dir is None:
        raise ValueError('No such code in wandb_dir')
    ckpt_dir = dir / 'files' / 'checkpoints'
    config_path = dir / 'files' / 'config.yaml'

    vocab_path = next(ckpt_dir.glob('vocab*'), None)
    if vocab_path is None:
        raise FileNotFoundError(f'No vocab* file found in {ckpt_dir}')

    last_ckpt_fn = next(ckpt_dir.glob('*last.pt'), None)
    if last_ckpt_fn is None:
        pt_fns = sorted(
            ckpt_dir.glob('*.pt'),
            key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')),
        )
        if not pt_fns:
            raise FileNotFoundError(f'No *.pt checkpoint found in {ckpt_dir}')
        last_ckpt_fn = pt_fns[-1]
    return last_ckpt_fn, config_path, vocab_path


def prepare_model_and_dataset_from_config(config: DictConfig, vocab_path: str):
    """Build the Amadeus model (with untrained weights) and its vocabulary from a run config.

    Args:
        config: experiment config, as produced by ``wandb_style_config_to_omega_config``.
        vocab_path: path (str or Path) to the saved vocab file.

    Returns:
        ``(model, vocab)``. No dataset is created despite the function name.

    Raises:
        KeyError: if ``config.nn_params.encoding_scheme`` is not 'remi', 'cp' or 'nb'.
    """
    nn_params = config.nn_params
    vocab_path = Path(vocab_path)
    encoding_scheme = nn_params.encoding_scheme
    num_features = nn_params.num_features

    vocab_name = {'remi': 'LangTokenVocab', 'cp': 'MusicTokenVocabCP', 'nb': 'MusicTokenVocabNB'}
    vocab = getattr(vocab_utils, vocab_name[encoding_scheme])(
        in_vocab_file_path=vocab_path,
        event_data=None,
        encoding_scheme=encoding_scheme,
        num_features=num_features,
    )

    prediction_order = adjust_prediction_order(
        encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params
    )

    # The sub-decoder section may be absent from the config; fall back to depth 0 / no enricher.
    has_sub_decoder = hasattr(nn_params, 'sub_decoder')
    sub_decoder_depth = nn_params.sub_decoder.num_layer if has_sub_decoder else 0
    sub_decoder_enricher_use = (
        nn_params.sub_decoder.feature_enricher_use
        if has_sub_decoder and hasattr(nn_params.sub_decoder, 'feature_enricher_use')
        else False
    )

    AmadeusModel = getattr(model_zoo, nn_params.model_name)(
        vocab=vocab,
        input_length=config.train_params.input_length,
        prediction_order=prediction_order,
        input_embedder_name=nn_params.input_embedder_name,
        main_decoder_name=nn_params.main_decoder_name,
        sub_decoder_name=nn_params.sub_decoder_name,
        sub_decoder_depth=sub_decoder_depth,
        sub_decoder_enricher_use=sub_decoder_enricher_use,
        dim=nn_params.main_decoder.dim_model,
        heads=nn_params.main_decoder.num_head,
        depth=nn_params.main_decoder.num_layer,
        dropout=nn_params.model_dropout,
    )
    return AmadeusModel, vocab


def load_resources(wandb_exp_dir, device):
    """Load config, model weights and vocab from an experiment directory.

    Args:
        wandb_exp_dir: experiment directory (str or Path).
        device: torch device the checkpoint is mapped to and the model is moved to.

    Returns:
        ``(config, model, vocab)`` with ``model`` on ``device`` and in eval mode.
    """
    wandb_exp_dir = Path(wandb_exp_dir)
    ckpt_path, config_path, vocab_path = get_best_ckpt_path_and_config(wandb_exp_dir)
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)
    ckpt = torch.load(ckpt_path, map_location=device)
    model, vocab = prepare_model_and_dataset_from_config(config, vocab_path)

    # strict=False tolerates key mismatches; report them instead of ignoring them silently.
    load_result = model.load_state_dict(ckpt['model'], strict=False)
    missing_keys = getattr(load_result, 'missing_keys', [])
    unexpected_keys = getattr(load_result, 'unexpected_keys', [])
    if missing_keys:
        print(f"warning: {len(missing_keys)} model keys missing from checkpoint, e.g. {missing_keys[:5]}")
    if unexpected_keys:
        print(f"warning: {len(unexpected_keys)} unexpected checkpoint keys ignored, e.g. {unexpected_keys[:5]}")

    model.to(device)
    model.eval()
    # Note: torch.compile(model) returns a new wrapped module and does not modify `model`
    # in place. The original discarded its result, so that call had no effect and is omitted.
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
    return config, model, vocab


@functools.lru_cache(maxsize=2)
def _load_text_encoder(text_encoder_model, device):
    """Load the T5 tokenizer and encoder once per (model name, device) and reuse them."""
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_model)
    encoder = T5EncoderModel.from_pretrained(text_encoder_model).to(device)
    encoder.eval()
    return tokenizer, encoder


def generate_with_text_prompt(config, vocab, model, device, prompt, text_encoder_model,
                              sampling_method='top_p', threshold=0.99, temperature=1.15,
                              generation_length=1024):
    """Generate a MIDI file conditioned on a text prompt and write it under ``outputs/``.

    The prompt is tokenized (padded/truncated to 128 tokens) and encoded with a T5 encoder.
    The encoder's last hidden state is passed to ``model.generate`` as ``context``.

    Args:
        config: experiment config (reads ``nn_params.encoding_scheme``, ``dataset``,
            ``data_params.first_pred_feature``).
        vocab: vocabulary returned by ``load_resources``.
        model: loaded Amadeus model.
        device: torch device used for the text encoder.
        prompt: text description of the music to generate.
        text_encoder_model: Hugging Face model id of the T5 encoder.
        sampling_method, threshold, temperature: forwarded to ``model.generate``.
        generation_length: forwarded to ``model.generate``.

    Returns:
        Path (str) of the written ``outputs/generated_<YYYYmmdd_HHMMSS>.mid`` file.
    """
    encoding_scheme = config.nn_params.encoding_scheme

    # Cached: the T5 weights are loaded once instead of on every request.
    tokenizer, encoder = _load_text_encoder(text_encoder_model, device)

    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    in_beat_resolution = in_beat_resolution_dict.get(config.dataset, 4)

    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(
        vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset
    )

    # Inference only: no autograd graph is needed for the text encoder or the sampler.
    with torch.no_grad():
        tokens = tokenizer(prompt, return_tensors='pt', padding='max_length',
                           truncation=True, max_length=128).to(device)
        context = encoder(**tokens).last_hidden_state
        generated_sample = model.generate(
            0,
            generation_length,
            condition=None,
            num_target_measures=None,
            sampling_method=sampling_method,
            threshold=threshold,
            temperature=temperature,
            context=context,
        )

    if encoding_scheme == 'nb':
        generated_sample = reverse_shift_and_pad_for_tensor(
            generated_sample, config.data_params.first_pred_feature
        )

    # === Build a timestamped output file name ===
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    Path("outputs").mkdir(exist_ok=True)
    output_file = Path("outputs") / f"generated_{timestamp}.mid"
    decoder(generated_sample, output_path=str(output_file))
    return str(output_file)


# === Gradio Demo ===
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_id = "models/Amadeus-S"  # model path; can be Amadeus-S, Amadeus-M or Amadeus-L

# Download the model from the Hugging Face Hub if it is not available locally.
if not Path(model_id).exists():
    from huggingface_hub import snapshot_download

    os.makedirs("models", exist_ok=True)
    # NOTE(review): the repo id is fixed to Amadeus-S and the snapshot lands in "models/",
    # which presumes the repo contains an "Amadeus-S/" folder — confirm before changing model_id.
    local_dir = snapshot_download(
        repo_id="longyu1315/Amadeus-S",
        repo_type="model",
        local_dir="models",
    )
    print("模型已下载到:", local_dir)

config, model, vocab = load_resources(model_id, device)

# Example prompts (also rendered below the generator and offered as one-click buttons).
examples = {
    "prompt1": "A melodic electronic ambient song with a touch of darkness, set in the key of E major and a 4/4 time signature. Tubular bells, electric guitar, synth effects, synth pad, and oboe weave together to create an epic, space-like atmosphere. The tempo is a steady Andante, and the chord progression of A, B, and E forms the harmonic backbone of this captivating piece.",
    "prompt2": "A melodic electronic song with a moderate tempo, featuring a blend of drums, piano, brass section, alto saxophone, and synth bass. The piece is set in B minor and follows a chord progression of C#m, B, A, and B. With a duration of 252 seconds, it evokes a dreamy and relaxing atmosphere, perfect for corporate settings.",
    "prompt3": "A soothing pop song that evokes feelings of love and relaxation, featuring a gentle blend of piano, flute, violin, and acoustic guitar. Set in the key of C major with a 4/4 time signature, the piece moves at an Andante tempo, creating a meditative and emotional atmosphere. The chord progression of G, C, F, G, and C adds to the song's calming ambiance.",
    "prompt4": "A lively and melodic rock song with a touch of pop, featuring pizzicato strings that add a playful and upbeat vibe. The piece is set in A minor and maintains a fast tempo of 148 beats per minute, with a 4/4 time signature. The chord progression of C, G, Fmaj7, C, and G repeats throughout the song, creating a catchy and energetic atmosphere that's perfect for corporate or background music.",
}


def gradio_generate(prompt, threshold, temperature, length):
    """Gradio callback: generate a MIDI file and locate its WAV preview.

    Returns:
        ``(midi_path, audio_path)``. ``audio_path`` is None when the expected WAV file
        does not exist, so the MIDI download still works instead of the event failing.
    """
    # Amadeus-M / Amadeus-L use flan-t5-large, anything else flan-t5-base (original selection logic).
    if "Amadeus-M" in model_id or "Amadeus-L" in model_id:
        text_encoder_model = 'google/flan-t5-large'
    else:
        text_encoder_model = 'google/flan-t5-base'

    midi_path = generate_with_text_prompt(
        config, vocab, model, device, prompt, text_encoder_model,
        threshold=threshold,
        temperature=temperature,
        # Slider values may arrive as floats; the generation length should be an integer.
        generation_length=int(length),
    )

    # === Derive the WAV file name from the MIDI file name ===
    # e.g. outputs/generated_X.mid -> outputs/music/generated_X.wav
    audio_path = midi_path.replace('.mid', '.wav').replace('generated', 'music/generated')
    if not Path(audio_path).exists():
        print(f"audio preview not found: {audio_path}")
        audio_path = None
    return midi_path, audio_path


with gr.Blocks() as demo:
    gr.Markdown("# 🎵 Amadeus MIDI Generation Demo")
    gr.Markdown(
        "### 🎵 Prompt 输入指南\n"
        "请尽量包含以下要素:\n"
        "- 曲风(如 pop, electronic, ambient...)\n"
        "- 乐器(如 piano, guitar, drums, strings...)\n"
        "- 调式(如 C major, F# minor...)\n"
        "- 拍号(如 4/4, 3/4...)\n"
        "- 速度(如 120 BPM, Andante, Allegro...)\n"
        "- 和弦走向(如 C, G, Am, F...)\n"
        "- 情绪(如 happy, relaxing, motivational...)\n\n"
        "推荐从示例中选择初始 Prompt 进行修改。"
    )
    with gr.Row():
        prompt = gr.Textbox(
            label="输入文本描述 (Prompt)",
            placeholder="A lively rock and electronic fusion, this song radiates happiness and energy. Distorted guitars, a rock organ, and driving drums propel the melody forward in a fast-paced 4/4 time signature. Set in the key of A major, it features a chord progression of E, D, A/G, E, and D, creating a dynamic and engaging sound that would be right at home in a video game soundtrack.",
        )
    with gr.Row():
        threshold = gr.Slider(0.5, 1.0, 0.99, step=0.01, label="阈值")
        temperature = gr.Slider(0.5, 3.0, 1.25, step=0.05, label="温度")
        length = gr.Slider(256, 3072, 1024, step=128, label="生成长度")
    generate_btn = gr.Button("生成 MIDI 🎼")
    midi_file = gr.File(label="下载生成的 MIDI 文件")
    audio_output = gr.Audio(label="生成的音频预览", type="filepath")
    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt, threshold, temperature, length],
        outputs=[midi_file, audio_output],
    )

    # Render the example prompts from the `examples` dict so the two never drift apart.
    gr.Markdown(
        "### 示例 Prompt\n"
        + "\n\n".join(f"{name}: {text}" for name, text in examples.items())
    )
    with gr.Row():
        for name, text in examples.items():
            # Clicking a button copies its example text into the prompt box
            # (the default argument binds the current `text`, avoiding late binding).
            btn = gr.Button(name)
            btn.click(lambda t=text: t, None, prompt)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)