Skip to content

Commit

Permalink
first commit for dreamfusion
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Dec 30, 2022
1 parent 0a58456 commit 38dafeb
Show file tree
Hide file tree
Showing 13 changed files with 1,789 additions and 5 deletions.
3 changes: 2 additions & 1 deletion mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .basic_image_dataset import BasicImageDataset
from .cifar10_dataset import CIFAR10
from .comp1k_dataset import AdobeComp1kDataset
from .dummy_dataset import DummyDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .imagenet_dataset import ImageNet
from .mscoco_dataset import MSCoCoDataset
Expand All @@ -15,5 +16,5 @@
'AdobeComp1kDataset', 'BasicImageDataset', 'BasicFramesDataset',
'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
'MSCoCoDataset'
'MSCoCoDataset', 'DummyDataset'
]
30 changes: 30 additions & 0 deletions mmedit/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

from torch.utils.data import Dataset

from mmedit.registry import DATASETS


@DATASETS.register_module()
class DummyDataset(Dataset):

def __init__(self, max_length=100, batch_size=None, sample_kwargs=None):
super().__init__()
self.max_length = max_length
self.sample_kwargs = sample_kwargs
self.batch_size = batch_size

def __len__(self):
return self.max_length

def __getitem__(self, index):
data_dict = dict()
input_dict = dict()
if self.batch_size is not None:
input_dict['num_batches'] = self.batch_size
if self.sample_kwargs is not None:
input_dict['sample_kwargs'] = deepcopy(self.sample_kwargs)

data_dict['inputs'] = input_dict
return data_dict
3 changes: 2 additions & 1 deletion mmedit/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dreamfusion_hook import DreamFusionTrainingHook
from .ema import ExponentialMovingAverageHook
from .iter_time_hook import GenIterTimerHook
from .pggan_fetch_data_hook import PGGANFetchDataHook
Expand All @@ -9,5 +10,5 @@
__all__ = [
'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'GenVisualizationHook',
'ExponentialMovingAverageHook', 'GenIterTimerHook', 'PGGANFetchDataHook',
'PickleDataHook'
'PickleDataHook', 'DreamFusionTrainingHook'
]
55 changes: 55 additions & 0 deletions mmedit/engine/hooks/dreamfusion_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmedit.registry import HOOKS


@HOOKS.register_module()
class DreamFusionTrainingHook(Hook):

def __init__(self, albedo_iters: int):
super().__init__()
self.albedo_iters = albedo_iters

self.shading_test = 'albedo'
self.ambident_ratio_test = 1.0

def set_shading_and_ambient(self, runner, shading: str,
ambient_ratio: str) -> None:
model = runner.model
if is_model_wrapper(model):
model = model.module
renderer = model.renderer
if is_model_wrapper(renderer):
renderer = renderer.module
renderer.set_shading(shading)
renderer.set_ambient_ratio(ambient_ratio)

def after_train_iter(self, runner, batch_idx: int, *args,
**kwargs) -> None:
if batch_idx < self.albedo_iters or self.albedo_iters == -1:
shading = 'albedo'
ambient_ratio = 1.0
else:
rand = random.random()
if rand > 0.8: # NOTE: this should be 0.75 in paper
shading = 'albedo'
ambient_ratio = 1.0
elif rand > 0.4: # NOTE: this should be 0.75 * 0.5 = 0.325
shading = 'textureless'
ambient_ratio = 0.1
else:
shading = 'lambertian'
ambient_ratio = 0.1
self.set_shading_and_ambient(runner, shading, ambient_ratio)

def before_test(self, runner) -> None:
self.set_shading_and_ambient(runner, self.shading_test,
self.ambident_ratio_test)

def before_val(self, runner) -> None:
self.set_shading_and_ambient(runner, self.shading_test,
self.ambident_ratio_test)
3 changes: 2 additions & 1 deletion mmedit/models/editors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
FeedbackBlockHeatmapAttention, LightCNN, MaxFeature)
from .dim import DIM
from .disco_diffusion import ClipWrapper, DiscoDiffusion
from .dreamfusion import DreamFusion
from .edsr import EDSRNet
from .edvr import EDVR, EDVRNet
from .eg3d import EG3D
Expand Down Expand Up @@ -87,5 +88,5 @@
'StyleGAN3Generator', 'InstColorization', 'NAFBaseline',
'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DDIMScheduler',
'DDPMScheduler', 'DenoisingUnet', 'ClipWrapper', 'EG3D', 'Restormer',
'SwinIRNet', 'StableDiffusion'
'SwinIRNet', 'StableDiffusion', 'DreamFusion'
]
10 changes: 10 additions & 0 deletions mmedit/models/editors/dreamfusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .camera import DreamFusionCamera
from .dreamfusion import DreamFusion
from .renderer import DreamFusionRenderer
from .stable_diffusion_wrapper import StableDiffusionWrapper

__all__ = [
'DreamFusion', 'DreamFusionRenderer', 'DreamFusionCamera',
'StableDiffusionWrapper'
]
22 changes: 22 additions & 0 deletions mmedit/models/editors/dreamfusion/activate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class _trunc_exp(Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float)
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.exp(x)

@staticmethod
@custom_bwd
def backward(ctx, g):
x = ctx.saved_tensors[0]
return g * torch.exp(x.clamp(max=15))


trunc_exp = _trunc_exp.apply
Loading

0 comments on commit 38dafeb

Please sign in to comment.