Model:
mrm8488/bloom-560m-finetuned-sd-prompts
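This model is a fine-tuned version of bigscience/bloom-560m. At the end of training (epoch 2.0) it reaches a validation loss of 0.8742, the last row of the training results table.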
Usage:

```python
import torch
from transformers import BloomTokenizerFast, BloomForCausalLM

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ckpt = 'mrm8488/bloom-560m-finetuned-sd-prompts'
tokenizer = BloomTokenizerFast.from_pretrained(ckpt)
model = BloomForCausalLM.from_pretrained(ckpt).to(device)


def generate_prompt(text):
    # Tokenize the seed text and move the tensors to the model's device.
    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    # Extend the prompt until EOS (or max_length tokens), lightly penalizing repetition.
    output = model.generate(input_ids, attention_mask=attention_mask,
                            repetition_penalty=1.05, max_length=2048,
                            eos_token_id=tokenizer.eos_token_id)
    # Keep special tokens so the <s> ... </s> boundaries stay visible.
    return tokenizer.decode(output[0], skip_special_tokens=False)


# The input follows the "<s>Prompt: <short description>" format.
text = "<s>Prompt: pikachu dinning in the eiffel tower"
print(generate_prompt(text))
# Output:
# <s>Prompt: pikachu dinning in the eiffel tower, intricate, elegant, highly detailed,
# digital painting, artstation, concept art, smooth, sharp focus, illustration,
# art by artgerm and greg rutkowski and alphonse mucha</s>
```
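With `skip_special_tokens=False`, the decoded string keeps the `<s>` and `</s>` markers and the "Prompt:" prefix, as in the output above. If only the prompt text is wanted (for example, to pass to a Stable Diffusion pipeline), a small post-processing step is enough. The sketch below is a hypothetical helper: `extract_prompt` is not part of this model or of transformers, and it assumes the `<s>Prompt: ... </s>` format shown in the example.

```python
def extract_prompt(decoded: str) -> str:
    """Hypothetical helper: strip BOS/EOS markers and the 'Prompt:' prefix.

    Assumes the "<s>Prompt: ... </s>" output format shown in the usage example.
    """
    text = decoded.strip()
    # Drop the leading BOS marker if present.
    if text.startswith("<s>"):
        text = text[len("<s>"):]
    # Cut at the first EOS marker; anything after it is not part of the prompt.
    text = text.split("</s>", 1)[0].strip()
    # Drop the "Prompt:" prefix used in the input format.
    prefix = "Prompt:"
    if text.startswith(prefix):
        text = text[len(prefix):]
    return text.strip()


raw = ("<s>Prompt: pikachu dinning in the eiffel tower, intricate, elegant, "
       "highly detailed, digital painting</s>")
print(extract_prompt(raw))
```

Decoding with `skip_special_tokens=True` instead would drop the `<s>`/`</s>` markers, but the "Prompt:" prefix would still be in the returned text.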
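Model description: More information needed.
Intended uses & limitations: More information needed.
Training and evaluation data: More information needed.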
The following hyperparameters were used during training:

Training results:
Training Loss | Epoch | Step | Validation Loss |
---|---|---|---|
2.6743 | 0.17 | 100 | 2.0891 |
1.8919 | 0.33 | 200 | 1.7191 |
1.5907 | 0.5 | 300 | 1.4454 |
1.3865 | 0.67 | 400 | 1.3247 |
1.2487 | 0.83 | 500 | 1.2150 |
1.1565 | 1.0 | 600 | 1.1031 |
0.896 | 1.17 | 700 | 1.0612 |
0.8389 | 1.33 | 800 | 0.9994 |
0.8071 | 1.5 | 900 | 0.9530 |
0.7628 | 1.67 | 1000 | 0.9206 |
0.7423 | 1.83 | 1100 | 0.8883 |
0.7155 | 2.0 | 1200 | 0.8742 |
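Two quantities are easy to derive from the table. Epoch 1.0 is logged at step 600 (and epoch 2.0 at step 1200), so one epoch is roughly 600 optimizer steps. And, assuming the losses are mean per-token cross-entropy in nats (the usual causal language modeling loss in transformers), perplexity is simply exp(loss). The sketch below copies the table rows by hand; `ROWS`, `perplexity` and `steps_per_epoch` are illustrative names, not part of any library.

```python
import math

# (step, epoch, training loss, validation loss), copied from the table above.
ROWS = [
    (100, 0.17, 2.6743, 2.0891),
    (200, 0.33, 1.8919, 1.7191),
    (300, 0.5, 1.5907, 1.4454),
    (400, 0.67, 1.3865, 1.3247),
    (500, 0.83, 1.2487, 1.2150),
    (600, 1.0, 1.1565, 1.1031),
    (700, 1.17, 0.896, 1.0612),
    (800, 1.33, 0.8389, 0.9994),
    (900, 1.5, 0.8071, 0.9530),
    (1000, 1.67, 0.7628, 0.9206),
    (1100, 1.83, 0.7423, 0.8883),
    (1200, 2.0, 0.7155, 0.8742),
]


def perplexity(loss: float) -> float:
    """exp(loss); only meaningful if loss is mean per-token cross-entropy in nats."""
    return math.exp(loss)


def steps_per_epoch(rows) -> float:
    """Estimate optimizer steps per epoch from the first row logged at a whole epoch."""
    for step, epoch, _train, _val in rows:
        if epoch.is_integer():
            return step / epoch
    raise ValueError("no row logged at a whole epoch")


for step, epoch, train_loss, val_loss in ROWS:
    print(f"step {step:5d}  epoch {epoch:4.2f}  "
          f"train ppl {perplexity(train_loss):5.2f}  val ppl {perplexity(val_loss):5.2f}")

print(f"steps per epoch: {steps_per_epoch(ROWS):.0f}")
print(f"final validation perplexity: {perplexity(ROWS[-1][3]):.2f}")
```

Under these assumptions the script reports 600 steps per epoch, and the final validation loss of 0.8742 corresponds to a perplexity of about 2.40.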