```python
# ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
    next_scales = []
    B = gt_ms_idx_Bl[0].shape[0]
    C = self.Cvae
    H = W = self.v_patch_nums[-1]
    SN = len(self.v_patch_nums)

    f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
    pn_next: int = self.v_patch_nums[0]
    for si in range(SN-1):
        if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break  # progressive training: not supported yet, prog_si always -1
        h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
        f_hat.add_(self.quant_resi[si/(SN-1)](h_BChw))
        pn_next = self.v_patch_nums[si+1]
        next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))
    return torch.cat(next_scales, dim=1) if len(next_scales) else None  # cat BlCs to BLC, this should be float32
```
Hello, I have a question that came up while reading the code of your work.

When the ground-truth image is converted into the teacher-forcing input for the Transformer, the input used to predict r_{i+1} is not r_i itself, but the accumulated sum of all previous scales.

Specifically, each scale appended to `next_scales` is not `self.embedding(gt_ms_idx_Bl[si])` (i.e., r_i), but `F.interpolate(f_hat, size=(pn_next, pn_next), mode='area')`, i.e., the accumulated sum of all previous scales (r_1 through r_i). This differs slightly from the description in the paper. Is there a reason for ultimately adopting this cumulative-sum input rather than the single-scale tokens, for example, did it perform better empirically than the formulation described in the paper?
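To make the distinction concrete, here is a minimal, self-contained sketch of the point being asked about (dummy tensors, a made-up codebook, and `quant_resi` omitted; this is not the repository's code): at each step, the teacher-forcing input is the area-downsampled running sum `f_hat` of all previous scales' upsampled embeddings, rather than the downsampled embedding of the current scale alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes: 3 scales, tiny codebook/feature dims, for illustration only.
B, C, V = 2, 8, 16                 # batch, Cvae, vocab size
patch_nums = (1, 2, 4)             # per-scale side lengths; final scale is 4x4
H = W = patch_nums[-1]
SN = len(patch_nums)

embedding = nn.Embedding(V, C)     # stand-in for the VQ-VAE codebook

# Dummy ground-truth indices per scale, each of shape (B, pn*pn).
gt_ms_idx_Bl = [torch.randint(0, V, (B, pn * pn)) for pn in patch_nums]

f_hat = torch.zeros(B, C, H, W)    # running reconstruction, accumulated over scales
next_scales = []                   # teacher-forcing inputs for scales 2..SN
pn = patch_nums[0]
for si in range(SN - 1):
    # r_si: this scale's token embeddings, upsampled to the final resolution
    r_si = embedding(gt_ms_idx_Bl[si]).transpose(1, 2).reshape(B, C, pn, pn)
    f_hat = f_hat + F.interpolate(r_si, size=(H, W), mode='bicubic')
    # (the real code additionally passes this through quant_resi; omitted here)

    pn = patch_nums[si + 1]
    # Input for predicting scale si+1: NOT r_si alone, but the accumulated f_hat
    # (r_1 + ... + r_si) downsampled to the next scale's resolution.
    cumulative_input = F.interpolate(f_hat, size=(pn, pn), mode='area')
    single_scale_alt = F.interpolate(r_si, size=(pn, pn), mode='area')  # the alternative the question contrasts it with
    next_scales.append(cumulative_input.reshape(B, C, -1).transpose(1, 2))

x_BLC = torch.cat(next_scales, dim=1)   # (B, l_2 + ... + l_SN, C) teacher-forcing input
print(x_BLC.shape)                       # torch.Size([2, 20, 8]) with these patch_nums
```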