import os

import gensim
import h5py
import librosa
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
# -----------------prepare for extract feature--------------------------
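# NOTE: the helper get_dataset() used by image_file() below is not part of this
# excerpt. The sketch here is one plausible implementation, assuming every
# sample's images/ directory holds exactly 12 frames as .jpg/.png files and
# that one batch should contain all of them; the class name FolderImageDataset
# and the normalization constants are assumptions, not the original code.
class FolderImageDataset(Dataset):
    def __init__(self, path, img_scale):
        self.files = sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        )
        self.transform = transforms.Compose([
            transforms.Resize((img_scale, img_scale)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index]).convert("RGB")
        return self.transform(img)


def get_dataset(path, img_scale, batch_size):
    # batch_size = 12 puts all 12 images of a sample into a single batch,
    # matching the 12 * 4096 feature matrix stored per sample
    return DataLoader(FolderImageDataset(path, img_scale), batch_size=batch_size, shuffle=False)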
# -----------------begin to extract feature--------------------------
def text_file(path, h5base, id, model):
    finnal_matrix = None
    with open(path + 'text.txt', 'r') as f:
        texts = f.read()
    text_list = texts.strip().split(" ")
    for word in text_list:
        try:
            vec_word = model[word]
        except KeyError:
            # out-of-vocabulary words get an all-zero vector
            vec_word = np.zeros(300)
        vec_word = vec_word.reshape(300, 1)
        if finnal_matrix is None:
            finnal_matrix = vec_word
        else:
            finnal_matrix = np.concatenate((finnal_matrix, vec_word), axis=1)
    h5base.create_dataset(id, data=finnal_matrix)
    return
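# A note on the expected input (an assumption, it is not spelled out in this
# excerpt): <sample>/texts/text.txt holds space-separated English words, and
# text_file() stores one 300-dim word2vec column per word, i.e. a
# 300 * num_words matrix under the key id. A minimal standalone check,
# assuming demo.h5 is writable and word2vec is loaded as in __main__ below:
# with h5py.File("demo.h5", "w") as h5f:
#     text_file("./dataset1/0001/texts/", h5f, "0001", word2vec)
#     print(h5f["0001"].shape)   # -> (300, number_of_words)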
# -----------------get image feature and store--------------------------
def image_file(path, h5base, id, model):
    """
    :param path: file path
    :param h5base: store 12 * 4096 data
    :return:
    """
    try:
        img_scale = 224
        batch_size = 12
        data_loader = get_dataset(path, img_scale, batch_size)
        for _, data in enumerate(data_loader):
            # use the model to extract features
            features = model(Variable(data))
            # save features
            h5base.create_dataset(id, data=features.data.numpy())
    except Exception:
        print(id + " failed while extracting image features!")
    return
# -----------------get audio feature and store--------------------------
def audio_file(path, h5base, id):
    """
    :param path: file path
    :param h5base: store 6 * 512 data
    :return:
    """
    try:
        finnal_matrix = None
        use_audio = ["1.mp3", "2.mp3", "3.mp3", "4.mp3", "5.mp3", "6.mp3"]
        files = librosa.util.find_files(path, recurse=False, ext='mp3')
        for file in files:
            file_list = file.split("/")
            if file_list[-1] not in use_audio:
                continue
            y, sr = librosa.load(file)
            # n_fft = 1022 yields 1022 // 2 + 1 = 512 frequency bins
            D = librosa.stft(y, n_fft=1022)
            # average over time so each clip contributes one 512-dim column
            vec_D = np.mean(D, axis=1, keepdims=True)
            if finnal_matrix is None:
                finnal_matrix = vec_D
            else:
                finnal_matrix = np.concatenate((finnal_matrix, vec_D), axis=1)
        # judge final matrix shape
        if finnal_matrix is None or finnal_matrix.shape != (512, 6):
            print(id + " audio features have the wrong shape!")
            return
        h5base.create_dataset(id, data=finnal_matrix)
    except Exception:
        print(id + " failed while extracting audio features!")
    return
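# -----------------model helpers assumed by __main__--------------------------
# getAlexNet() and the AlexNet class used below are not included in this
# excerpt. The following is one plausible reconstruction, written for the older
# torchvision "pretrained=" API that matches the era of this script:
# getAlexNet() simply returns torchvision's AlexNet, and AlexNet keeps the
# convolutional block plus the classifier up to the second 4096-dim fully
# connected layer, so model(images) returns a 12 * 4096 feature matrix. The
# exact layer layout is an assumption, not the original code.
def getAlexNet(download):
    # torchvision downloads the ImageNet weights itself when download is True
    return models.alexnet(pretrained=download)


class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        # reuse torchvision's feature extractor so the state dict keys line up
        self.features = models.alexnet(pretrained=False).features
        # classifier truncated after fc7: its keys (classifier.1, classifier.4)
        # match torchvision's, while the final 1000-way layer is dropped
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        return self.classifier(x)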
# -----------------extract and store features for the whole dataset--------------------------
if __name__ == '__main__':
    # some config variables
    root_path = "./dataset1/"
    DOWNLOAD = False
    pre_train_weight_alexnet = "./alexnet-owt-4df8aa71.pth"
    # -----------------get AlexNet to extract image features--------------------------
    # if you have already downloaded the weights, keep DOWNLOAD = False
    pre_alexnet = getAlexNet(DOWNLOAD)
    pre_alexnet.load_state_dict(torch.load(pre_train_weight_alexnet))
    pretrain_dict = pre_alexnet.state_dict()
    alexnet = AlexNet()
    alexnet_dict = alexnet.state_dict()
    pretrained_dict = {k: v for k, v in pretrain_dict.items() if k in alexnet_dict}
    # update the weight of the new alexnet
    alexnet_dict.update(pretrained_dict)
    # load the new weight
    alexnet.load_state_dict(alexnet_dict)
    # -----------------get Google's word2vec to extract text features--------------------------
    word2vec = gensim.models.KeyedVectors.load_word2vec_format("./GoogleNews-vectors-negative300.bin", binary=True)
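    # The HDF5 handles restore_audio_file / restore_images_file /
    # restore_texts_file used in the loop below are not created anywhere in
    # this excerpt. One plausible setup is sketched here; the three output
    # file names are assumptions, not the original code.
    restore_audio_file = h5py.File("./audio_features.h5", "w")
    restore_images_file = h5py.File("./image_features.h5", "w")
    restore_texts_file = h5py.File("./text_features.h5", "w")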
    # loop over every sample directory under the dataset root
    dirs = os.listdir(root_path)
    process = 0
    total = len(dirs)
    for dir in dirs:
        audio_file(root_path + dir + "/audios/", restore_audio_file, dir)
        image_file(root_path + dir + "/images/", restore_images_file, dir, alexnet)
        text_file(root_path + dir + "/texts/", restore_texts_file, dir, word2vec)
        process += 1
        # max(..., 1) avoids division by zero when there are fewer than 10 samples
        if process % max(total // 10, 1) == 0:
            print("already done: " + str(process) + "/" + str(total))
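    # close the HDF5 files opened in the assumed setup above
    restore_audio_file.close()
    restore_images_file.close()
    restore_texts_file.close()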