模型:
ethzanalytics/mpt-7b-storywriter-sharded
这是 mpt-7b-storywriter 模型的一个版本,为了低内存加载(例如Colab),将其切分为2 GB的块。权重以bfloat16格式存储,理论上可以在CPU上运行,尽管可能需要很长时间。
请参考先前链接的仓库了解使用/实现等细节。该模型从原始仓库下载,并在相同许可下重新分发,采用Apache-2.0许可。
请注意使用时:这不是一个经过指令调整的模型,因此您需要提供足够的输入文本,以便使用提示在特定主题上生成内容。
安装/升级软件包:
pip install -U torch transformers accelerate einops
加载模型:
import torch from transformers import AutoModelForCausalLM, AutoTokenizer model_name = 'ethzanalytics/mpt-7b-storywriter-sharded' model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, revision='197d14245ad874da82194248cab1ce8cf87fa713', # optional, but a good idea device_map='auto', load_in_8bit=False, # install bitsandbytes then set to true for 8-bit ) model = torch.compile(model) tokenizer = AutoTokenizer.from_pretrained(model_name)
然后,您可以像通常一样使用 model.generate() - 有关详细信息,请参阅笔记本。