数据集:
nuprl/MultiPL-E
MultiPL-E 是用于评估支持 18 种编程语言的大型语言模型的代码生成数据集。它采用了 OpenAI 的 "HumanEval" 数据集和 MBPP Python 基准测试,并使用小型编译器将它们转换为其他语言。可以轻松添加对新语言和基准测试的支持。
对于大多数情况,您应该使用名为 "SRCDATA-LANG" 的变体,其中 "SRCDATA" 是 "humaneval" 或 "mbpp",而 "LANG" 是支持的语言之一。我们使用每种语言的规范文件扩展名来标识语言,例如 Python 的 "py",C++ 的 "cpp",Lua 的 "lua" 等。
我们还提供了一些其他的变体:
"SRCDATA-LANG-keep" 与 "SRCDATA-LANG" 相同,但是提示文本完全不变。如果原始提示中包含 Python 的 doctest,则它们仍保持为 Python,而不是被转换为 "LANG"。如果原始提示中包含 Python 特定的术语,例如 "list",它仍为 "list",而不是被转换为 "LANG" 中的 "vector"(例如 C++)。
"SRCDATA-LANG-transform" 将 doctest 转换为 "LANG",但保持提示的自然语言文本不变。
"SRCDATA-LANG-removed" 从提示中删除 doctest。
需要注意的是,MBPP 没有任何 doctest,因此不适用于 "removed" 和 "transform" 的变体。
以下脚本使用 Salesforce/codegen 模型生成 Lua 代码,并使用 MultiPL-E 为 luaunit 生成带有单元测试的脚本。
import datasets from transformers import AutoTokenizer, AutoModelForCausalLM LANG = "lua" MODEL_NAME = "Salesforce/codegen-350M-multi" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).half().cuda() problems = datasets.load_dataset("nuprl/MultiPL-E", f"humaneval-{LANG}") def stop_at_stop_token(decoded_string, problem): """ Truncates the output at stop tokens, taking care to skip the prompt which may have stop tokens. """ min_stop_index = len(decoded_string) for stop_token in problem["stop_tokens"]: stop_index = decoded_string.find(stop_token) if stop_index != -1 and stop_index > len(problem["prompt"]) and stop_index < min_stop_index: min_stop_index = stop_index return decoded_string[:min_stop_index] for problem in problems["test"]: input_ids = tokenizer( problem["prompt"], return_tensors="pt", ).input_ids.cuda() generated_ids = model.generate( input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id + 2 ) truncated_string = stop_at_stop_token(tokenizer.decode(generated_ids[0]), problem) filename = problem["name"] + "." + LANG with open(filename, "w") as f: print(f"Created {filename}") f.write(truncated_string) f.write("\n") f.write(problem["tests"])