-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathinference.py
69 lines (57 loc) · 2.45 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os

import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

from blora_utils import BLOCKS, filter_lora, scale_lora
def parse_args():
    """Collect B-LoRA inference settings from the command line.

    Returns:
        argparse.Namespace with: prompt, output_path, optional
        content_B_LoRA / style_B_LoRA checkpoint paths, their alpha
        scaling factors, and num_images_per_prompt.
    """
    parser = argparse.ArgumentParser()

    # Required string options.
    for flag, description in (
        ("--prompt", "B-LoRA prompt"),
        ("--output_path", "path to save the images"),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)

    # Optional checkpoint paths; None means "do not apply that B-LoRA".
    for flag, description in (
        ("--content_B_LoRA", "path for the content B-LoRA"),
        ("--style_B_LoRA", "path for the style B-LoRA"),
    ):
        parser.add_argument(flag, type=str, default=None, help=description)

    # Per-component weight scales, 1.0 keeps the checkpoint weights as-is.
    for flag, description in (
        ("--content_alpha", "alpha parameter to scale the content B-LoRA weights"),
        ("--style_alpha", "alpha parameter to scale the style B-LoRA weights"),
    ):
        parser.add_argument(flag, type=float, default=1., help=description)

    parser.add_argument(
        "--num_images_per_prompt", type=int, default=4,
        help="number of images per prompt",
    )
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # The fp16-fix VAE avoids numerical issues when running SDXL's VAE
    # in float16.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    # Fall back to CPU so the script still runs (slowly) without CUDA
    # instead of crashing on a hard-coded "cuda".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
                                                         vae=vae,
                                                         torch_dtype=torch.float16).to(device)
    # Get Content B-LoRA SD: keep only the content blocks, then rescale.
    if args.content_B_LoRA is not None:
        content_B_LoRA_sd, _ = pipeline.lora_state_dict(args.content_B_LoRA)
        content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
        content_B_LoRA = scale_lora(content_B_LoRA, args.content_alpha)
    else:
        content_B_LoRA = {}
    # Get Style B-LoRA SD: same filtering/scaling for the style blocks.
    if args.style_B_LoRA is not None:
        style_B_LoRA_sd, _ = pipeline.lora_state_dict(args.style_B_LoRA)
        style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
        style_B_LoRA = scale_lora(style_B_LoRA, args.style_alpha)
    else:
        style_B_LoRA = {}
    # Merge B-LoRAs SD (style entries win if a key ever appears in both).
    res_lora = {**content_B_LoRA, **style_B_LoRA}
    # Load the merged LoRA weights into the UNet.
    pipeline.load_lora_into_unet(res_lora, None, pipeline.unet)
    # Generate
    images = pipeline(args.prompt, num_images_per_prompt=args.num_images_per_prompt).images
    # Save: create the output directory if needed, and strip path
    # separators from the prompt so a prompt containing '/' or '\' cannot
    # produce an invalid path or escape the target directory.
    os.makedirs(args.output_path, exist_ok=True)
    safe_prompt = args.prompt.replace('/', '_').replace('\\', '_')
    for i, img in enumerate(images):
        img.save(os.path.join(args.output_path, f'{safe_prompt}_{i}.jpg'))