-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathraunet_nodes.py
158 lines (121 loc) · 4.67 KB
/
raunet_nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch.nn.functional as F
import comfy
from .model_patch import add_model_patch_option, patch_model_function_wrapper
class RAUNet:
    """ComfyUI node that returns a copy of a model with the RAUNet patch applied.

    The four integer inputs are step bounds forwarded unchanged to
    ``add_raunet_patch``. ``du_*`` bound the dilated-downsample window and
    ``xa_*`` bound the window used by the input/output block patches.
    """

    CATEGORY = "inpaint"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "model_update"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs for ComfyUI."""

        def step_input(default):
            # Every step bound shares the same 0..10000 integer range.
            return ("INT", {"default": default, "min": 0, "max": 10000})

        required = {"model": ("MODEL",)}
        required["du_start"] = step_input(0)
        required["du_end"] = step_input(4)
        required["xa_start"] = step_input(4)
        required["xa_end"] = step_input(10)
        return {"required": required}

    def model_update(self, model, du_start, du_end, xa_start, xa_end):
        """Clone ``model``, install the RAUNet patch on the clone, return it.

        Returns:
            A 1-tuple ``(patched_model,)``, matching ``RETURN_TYPES``.
        """
        patched = model.clone()
        add_raunet_patch(patched, du_start, du_end, xa_start, xa_end)
        return (patched,)
# Main patch entry point: registers the RAUNet hooks on a model.
def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
    """Register the RAUNet hooks on ``model`` and record their step windows.

    Hooks installed:

    * ``raunet_forward``, registered through ``patch_model_function_wrapper``.
      Each time it is invoked it reconfigures ``block.op`` of the
      ``Downsample`` at ``input_blocks[6][0]`` (SDXL) or ``input_blocks[3][0]``
      (non-SDXL).
      - For steps in ``[du_start, du_end)`` it sets stride 4, padding 2 and
        dilation 2.
      - Otherwise it restores stride 2, padding 1 and dilation 1.
    * ``in_xattn_patch``, registered as the input-block patch.
    * ``out_xattn_patch``, registered as the output-block patch.

    The four bounds are written to ``model_patch['raunet']`` in the options
    returned by ``add_model_patch_option``, so the hooks read them at run time.

    Args:
        model: model to patch in place (``RAUNet.model_update`` passes a clone).
        du_start, du_end: half-open step window for the dilated downsample.
        xa_start, xa_end: half-open step window read by the xattn patches.

    NOTE(review): ``raunet_forward`` mutates the downsample layer in place.
    If ``model.clone()`` shares the underlying UNet module with the source
    model, the altered stride/padding/dilation can outlive this patched model
    (e.g. when sampling stops inside the du window). Confirm.
    """

    def raunet_forward(model, x, timesteps, transformer_options, control):
        """Switch the downsample layer between dilated and default settings.

        Only the side effect on ``block.op`` matters. The function always
        returns ``None``; ``x``, ``timesteps`` and ``control`` are unused here
        and are presumably dictated by the wrapper's calling convention
        (confirm in ``model_patch``).
        """
        if 'model_patch' not in transformer_options:
            print("RAUNet: 'model_patch' not in transformer_options, skip")
            return
        mp = transformer_options['model_patch']
        is_SDXL = mp['SDXL']

        # The first downsample sits at a different input block for SDXL vs others.
        block = model.input_blocks[6 if is_SDXL else 3][0]
        downsample_cls = comfy.ldm.modules.diffusionmodules.openaimodel.Downsample
        if not isinstance(block, downsample_cls):
            if is_SDXL:
                print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
            else:
                print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
            return
        if 'raunet' not in mp:
            print('RAUNet: "raunet" not in model_patch options, skip')
            return

        step = mp['step']
        ro = mp['raunet']
        op = block.op
        if ro['du_start'] <= step < ro['du_end']:
            # Dilated, stride-4 downsample inside the du window.
            op.stride = (4, 4)
            op.padding = (2, 2)
            op.dilation = (2, 2)
        else:
            # Restore the default stride-2 downsample outside the window.
            op.stride = (2, 2)
            op.padding = (1, 1)
            op.dilation = (1, 1)

    patch_model_function_wrapper(model, raunet_forward)
    model.set_model_input_block_patch(in_xattn_patch)
    model.set_model_output_block_patch(out_xattn_patch)

    to = add_model_patch_option(model)
    ro = to['model_patch'].setdefault('raunet', {})
    ro.update(du_start=du_start, du_end=du_end, xa_start=xa_start, xa_end=xa_end)
def in_xattn_patch(h, transformer_options):
    """Input-block hook: 2x2 average-pool ``h`` at ``("input", 4)`` in the xa window.

    ``h`` is returned unchanged in these cases:
    - the current block is not ``("input", 4)`` (the same index for SDXL and SD15);
    - the RAUNet options are missing (a message is printed);
    - ``model_patch['step']`` lies outside ``[xa_start, xa_end)``.

    Otherwise ``h`` goes through ``F.avg_pool2d`` with a 2x2 kernel. Its default
    stride equals the kernel size, which halves the last two dimensions (odd
    sizes are floored).
    """
    # Both SDXL and SD15 use the same input block here.
    if transformer_options["block"] != ("input", 4):
        return h

    if 'model_patch' not in transformer_options:
        print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
        return h
    patch_opts = transformer_options['model_patch']

    if 'raunet' not in patch_opts:
        print("RAUNet (i-x-p): 'raunet' not in model_patch options")
        return h

    current = patch_opts['step']
    settings = patch_opts['raunet']
    lower = settings['xa_start']
    upper = settings['xa_end']

    if lower <= current < upper:
        return F.avg_pool2d(h, kernel_size=(2, 2))
    return h
def out_xattn_patch(h, hsp, transformer_options):
    """Output-block hook: resize ``h`` to ``hsp``'s spatial size in the xa window.

    The hook acts only at block ``("output", 5)`` for SDXL or ``("output", 8)``
    otherwise, and only while ``model_patch['step']`` is in ``[xa_start, xa_end)``.
    There, ``h`` is bicubically interpolated to the last two dimensions of
    ``hsp``. This presumably undoes the pooling done by ``in_xattn_patch``
    (confirm the block pairing).

    Args:
        h: hidden-state tensor for the current output block.
        hsp: tensor whose last two dimensions give the target size.
        transformer_options: dict with ``"block"`` and ``"model_patch"`` entries.

    Returns:
        Always a ``(h, hsp)`` tuple. ``hsp`` is passed through untouched.
    """
    if 'model_patch' not in transformer_options:
        print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
        return h, hsp
    mp = transformer_options['model_patch']
    if 'raunet' not in mp:
        print("RAUNet (o-x-p): 'raunet' not in model_patch options")
        # Previously returned a bare ``h`` here, breaking the (h, hsp) contract.
        return h, hsp

    step = mp['step']
    is_SDXL = mp['SDXL']
    ro = mp['raunet']
    xa_start = ro['xa_start']
    xa_end = ro['xa_end']

    target_block = ("output", 5) if is_SDXL else ("output", 8)
    if transformer_options["block"] != target_block:
        return h, hsp
    if step < xa_start or step >= xa_end:
        return h, hsp

    # The HiDiffusion reference upsamples by a fixed factor of 2. That only
    # matches hsp's size for some resolutions (e.g. not after an odd size was
    # floored by pooling), so resize to hsp's exact spatial size instead.
    re_size = (hsp.shape[-2], hsp.shape[-1])
    h = F.interpolate(h, size=re_size, mode='bicubic')
    return h, hsp