Model:
cardiffnlp/twitter-xlm-roberta-base
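This is an XLM-RoBERTa-base model trained on ~198M multilingual tweets. It is described and evaluated in the reference paper. To evaluate this and other language models on Twitter-specific data, refer to the main repository. A usage example is provided below.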
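from collections import defaultdict

import numpy as np
from scipy.spatial.distance import cosine
from transformers import AutoModel, AutoTokenizer

MODEL = "cardiffnlp/twitter-xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL)

def preprocess(text):
    # Replace user mentions with '@user' and links with 'http'
    new_text = []
    for t in text.split(" "):
        t = '@user' if t.startswith('@') and len(t) > 1 else t
        t = 'http' if t.startswith('http') else t
        new_text.append(t)
    return " ".join(new_text)

def get_embedding(text):
    text = preprocess(text)
    encoded_input = tokenizer(text, return_tensors='pt')
    features = model(**encoded_input)
    # features[0] is the last hidden state, shape (1, seq_len, hidden_size)
    features = features[0].detach().numpy()
    # Average the token vectors into a single sentence embedding
    features_mean = np.mean(features[0], axis=0)
    return features_mean

query = "Acabo de pedir pollo frito ?" #spanish
tweets = ["We had a great time! ⚽️", # english
          "We hebben een geweldige tijd gehad! ⛩", # dutch
          "Nous avons passé un bon moment! ?", # french
          "Ci siamo divertiti! ?"] # italian

# Score each tweet by cosine similarity to the query (1 - cosine distance)
query_embedding = get_embedding(query)
d = defaultdict(int)
for tweet in tweets:
    sim = 1 - cosine(query_embedding, get_embedding(tweet))
    d[tweet] = sim

print('Most similar to: ', query)
print('----------------------------------------')
for idx, x in enumerate(sorted(d.items(), key=lambda x: x[1], reverse=True)):
    print(idx + 1, x[0])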
Output:

Most similar to:  Acabo de pedir pollo frito ?
----------------------------------------
1 Ci siamo divertiti! ?
2 Nous avons passé un bon moment! ?
3 We had a great time! ⚽️
4 We hebben een geweldige tijd gehad! ⛩
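The preprocessing and ranking steps in the example do not depend on the model itself. The sketch below isolates them without any third-party dependencies, assuming toy 3-dimensional vectors (hypothetical values) in place of the real 768-dimensional mean-pooled embeddings. The helper names `cosine_similarity` and `rank_by_similarity` are illustrative only; they are not part of the model or of the transformers library.

```python
import math

def preprocess(text: str) -> str:
    """Mask user mentions and links, following the preprocess() helper above."""
    new_text = []
    for t in text.split(" "):
        t = "@user" if t.startswith("@") and len(t) > 1 else t
        t = "http" if t.startswith("http") else t
        new_text.append(t)
    return " ".join(new_text)

def cosine_similarity(u, v):
    """Hypothetical helper: equals 1 - scipy.spatial.distance.cosine(u, v)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def rank_by_similarity(query_vec, candidates):
    """Hypothetical helper: candidate keys sorted from most to least similar."""
    scores = {name: cosine_similarity(query_vec, vec) for name, vec in candidates.items()}
    return sorted(scores, key=scores.get, reverse=True)

# Toy vectors standing in for real sentence embeddings (hypothetical values).
query = [1.0, 0.0, 1.0]
candidates = {
    "a": [0.9, 0.1, 1.0],
    "b": [0.0, 1.0, 0.0],
    "c": [1.0, 0.0, 0.0],
}

print(preprocess("@cardiffnlp check https://example.com"))  # @user check http
print(rank_by_similarity(query, candidates))                # ['a', 'c', 'b']
```

Note that a lone "@" is left unchanged because of the `len(t) > 1` check, and any token starting with "http" (not only URLs) is replaced.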