[Feature] Support DreamFusion #1563

Draft · wants to merge 4 commits into main
4 changes: 3 additions & 1 deletion mmagic/datasets/__init__.py
@@ -6,6 +6,7 @@
 from .comp1k_dataset import AdobeComp1kDataset
 from .controlnet_dataset import ControlNetDataset
 from .dreambooth_dataset import DreamBoothDataset
+from .dummy_dataset import DummyDataset
 from .grow_scale_image_dataset import GrowScaleImgDataset
 from .imagenet_dataset import ImageNet
 from .mscoco_dataset import MSCoCoDataset
@@ -19,5 +20,6 @@
     'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
     'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
     'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset',
-    'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset'
+    'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset',
+    'DummyDataset'
 ]
30 changes: 30 additions & 0 deletions mmagic/datasets/dummy_dataset.py
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

from torch.utils.data import Dataset

from mmagic.registry import DATASETS


@DATASETS.register_module()
class DummyDataset(Dataset):

    def __init__(self, max_length=100, batch_size=None, sample_kwargs=None):
        super().__init__()
        self.max_length = max_length
        self.sample_kwargs = sample_kwargs
        self.batch_size = batch_size

    def __len__(self):
        return self.max_length

    def __getitem__(self, index):
        data_dict = dict()
        input_dict = dict()
        if self.batch_size is not None:
            input_dict['num_batches'] = self.batch_size
        if self.sample_kwargs is not None:
            input_dict['sample_kwargs'] = deepcopy(self.sample_kwargs)

        data_dict['inputs'] = input_dict
        return data_dict
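
This dataset stores no data; each item only carries sampling arguments for a generative pipeline. A minimal sketch of what it yields — the sample_kwargs contents here are placeholders, not keys any consumer of this dataset necessarily expects:

from mmagic.datasets import DummyDataset

# Items carry only sampling arguments; there is no stored data.
dataset = DummyDataset(max_length=100, batch_size=4,
                       sample_kwargs=dict(seed=0))
assert len(dataset) == 100
print(dataset[0])
# {'inputs': {'num_batches': 4, 'sample_kwargs': {'seed': 0}}}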
3 changes: 2 additions & 1 deletion mmagic/engine/hooks/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dreamfusion_hook import DreamFusionTrainingHook
 from .ema import ExponentialMovingAverageHook
 from .iter_time_hook import IterTimerHook
 from .pggan_fetch_data_hook import PGGANFetchDataHook
@@ -9,5 +10,5 @@
 __all__ = [
     'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'VisualizationHook',
     'ExponentialMovingAverageHook', 'IterTimerHook', 'PGGANFetchDataHook',
-    'PickleDataHook'
+    'PickleDataHook', 'DreamFusionTrainingHook'
 ]
55 changes: 55 additions & 0 deletions mmagic/engine/hooks/dreamfusion_hook.py
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmagic.registry import HOOKS


@HOOKS.register_module()
class DreamFusionTrainingHook(Hook):

    def __init__(self, albedo_iters: int):
        super().__init__()
        self.albedo_iters = albedo_iters

        self.shading_test = 'albedo'
        self.ambient_ratio_test = 1.0

    def set_shading_and_ambient(self, runner, shading: str,
                                ambient_ratio: float) -> None:
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        renderer = model.renderer
        if is_model_wrapper(renderer):
            renderer = renderer.module
        renderer.set_shading(shading)
        renderer.set_ambient_ratio(ambient_ratio)

    def after_train_iter(self, runner, batch_idx: int, *args,
                         **kwargs) -> None:
        if batch_idx < self.albedo_iters or self.albedo_iters == -1:
            shading = 'albedo'
            ambient_ratio = 1.0
        else:
            rand = random.random()
            if rand > 0.8:  # NOTE: this should be 0.75 in the paper
                shading = 'albedo'
                ambient_ratio = 1.0
            elif rand > 0.4:  # NOTE: this should be 0.75 * 0.5 = 0.375
                shading = 'textureless'
                ambient_ratio = 0.1
            else:
                shading = 'lambertian'
                ambient_ratio = 0.1
        self.set_shading_and_ambient(runner, shading, ambient_ratio)

    def before_test(self, runner) -> None:
        self.set_shading_and_ambient(runner, self.shading_test,
                                     self.ambient_ratio_test)

    def before_val(self, runner) -> None:
        self.set_shading_and_ambient(runner, self.shading_test,
                                     self.ambient_ratio_test)
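
The hook trains the first albedo_iters iterations with albedo-only shading, then randomly mixes albedo, textureless, and lambertian shading; validation and testing always use albedo. Assuming it is wired up through MMEngine's standard custom_hooks mechanism (the PR does not show a config), enabling it would look roughly like this, with albedo_iters=1000 as an illustrative value:

custom_hooks = [
    dict(type='DreamFusionTrainingHook', albedo_iters=1000),
]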
3 changes: 2 additions & 1 deletion mmagic/models/editors/__init__.py
@@ -18,6 +18,7 @@
 from .dim import DIM
 from .disco_diffusion import ClipWrapper, DiscoDiffusion
 from .dreambooth import DreamBooth
+from .dreamfusion import DreamFusion
 from .edsr import EDSRNet
 from .edvr import EDVR, EDVRNet
 from .eg3d import EG3D
@@ -89,5 +90,5 @@
     'StyleGAN3Generator', 'InstColorization', 'NAFBaseline',
     'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DenoisingUnet',
     'ClipWrapper', 'EG3D', 'Restormer', 'SwinIRNet', 'StableDiffusion',
-    'ControlStableDiffusion', 'DreamBooth', 'TextualInversion'
+    'ControlStableDiffusion', 'DreamBooth', 'TextualInversion', 'DreamFusion'
 ]
10 changes: 10 additions & 0 deletions mmagic/models/editors/dreamfusion/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .camera import DreamFusionCamera
from .dreamfusion import DreamFusion
from .renderer import DreamFusionRenderer
from .stable_diffusion_wrapper import StableDiffusionWrapper

__all__ = [
    'DreamFusion', 'DreamFusionRenderer', 'DreamFusionCamera',
    'StableDiffusionWrapper'
]
22 changes: 22 additions & 0 deletions mmagic/models/editors/dreamfusion/activate.py
@@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class _trunc_exp(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float)  # always run the exp in fp32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(max=15))  # clamp keeps the grad finite


trunc_exp = _trunc_exp.apply
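
A quick sanity-check sketch of the truncation: the forward pass is an ordinary exp, while the gradient at large inputs is computed from the clamped value exp(15) ≈ 3.27e6, so it stays finite even where an fp16 exp would overflow in the backward pass:

import torch

x = torch.tensor([1.0, 20.0], requires_grad=True)
y = trunc_exp(x)      # forward: plain exp
y.sum().backward()
print(x.grad)         # tensor([2.7183e+00, 3.2690e+06])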