-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0a58456
commit 38dafeb
Showing
13 changed files
with
1,789 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from copy import deepcopy | ||
|
||
from torch.utils.data import Dataset | ||
|
||
from mmedit.registry import DATASETS | ||
|
||
|
||
@DATASETS.register_module()
class DummyDataset(Dataset):
    """A placeholder dataset that yields empty (unconditional) inputs.

    Each item is a dict of the form ``{'inputs': {...}}`` where the inner
    dict optionally carries ``num_batches`` and ``sample_kwargs``.  Useful
    for generative models that need a dataloader but no real data.

    Args:
        max_length (int): Reported length of the dataset. Defaults to 100.
        batch_size (int, optional): If given, forwarded per item as
            ``inputs['num_batches']``. Defaults to None.
        sample_kwargs (dict, optional): If given, a deep copy is forwarded
            per item as ``inputs['sample_kwargs']``. Defaults to None.
    """

    def __init__(self, max_length=100, batch_size=None, sample_kwargs=None):
        super().__init__()
        self.max_length = max_length
        self.sample_kwargs = sample_kwargs
        self.batch_size = batch_size

    def __len__(self):
        """Return the configured dataset length."""
        return self.max_length

    def __getitem__(self, index):
        """Build one (index-independent) sample dict."""
        inputs = dict()
        if self.batch_size is not None:
            inputs['num_batches'] = self.batch_size
        if self.sample_kwargs is not None:
            # Deep-copy so downstream mutation cannot leak back into
            # the dataset's shared configuration.
            inputs['sample_kwargs'] = deepcopy(self.sample_kwargs)
        return {'inputs': inputs}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import random | ||
|
||
from mmengine.hooks import Hook | ||
from mmengine.model import is_model_wrapper | ||
|
||
from mmedit.registry import HOOKS | ||
|
||
|
||
@HOOKS.register_module()
class DreamFusionTrainingHook(Hook):
    """Switch the renderer's shading mode during DreamFusion training.

    For the first ``albedo_iters`` iterations (or for the whole run when
    ``albedo_iters`` is -1) the renderer uses plain albedo shading with
    full ambient light; afterwards a shading mode is sampled at random
    each iteration.  Validation and testing always use albedo shading.

    Args:
        albedo_iters (int): Number of warm-up iterations rendered with
            albedo shading only. -1 means albedo shading is used for the
            entire training run.
    """

    def __init__(self, albedo_iters: int):
        super().__init__()
        self.albedo_iters = albedo_iters

        # Fixed settings applied before val/test loops.
        # NOTE(review): 'ambident' is a historical typo; the attribute
        # name is kept for backward compatibility with external readers.
        self.shading_test = 'albedo'
        self.ambident_ratio_test = 1.0

    def set_shading_and_ambient(self, runner, shading: str,
                                ambient_ratio: float) -> None:
        """Push ``shading`` and ``ambient_ratio`` into the model's renderer.

        Args:
            runner: The active runner; ``runner.model`` (unwrapped if
                needed) must expose a ``renderer`` with ``set_shading`` and
                ``set_ambient_ratio`` methods.
            shading (str): Shading mode, one of 'albedo', 'textureless' or
                'lambertian'.
            ambient_ratio (float): Ambient light ratio for the renderer.
                (Fixed annotation: call sites pass floats, not str.)
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        renderer = model.renderer
        if is_model_wrapper(renderer):
            renderer = renderer.module
        renderer.set_shading(shading)
        renderer.set_ambient_ratio(ambient_ratio)

    def after_train_iter(self, runner, batch_idx: int, *args,
                         **kwargs) -> None:
        """Select the shading mode used for the next training iteration."""
        if batch_idx < self.albedo_iters or self.albedo_iters == -1:
            shading = 'albedo'
            ambient_ratio = 1.0
        else:
            rand = random.random()
            if rand > 0.8:  # NOTE: this should be 0.75 in paper
                shading = 'albedo'
                ambient_ratio = 1.0
            elif rand > 0.4:  # NOTE: this should be 0.75 * 0.5 = 0.325
                shading = 'textureless'
                ambient_ratio = 0.1
            else:
                shading = 'lambertian'
                ambient_ratio = 0.1
        self.set_shading_and_ambient(runner, shading, ambient_ratio)

    def before_test(self, runner) -> None:
        """Force albedo shading with full ambient light for testing."""
        self.set_shading_and_ambient(runner, self.shading_test,
                                     self.ambident_ratio_test)

    def before_val(self, runner) -> None:
        """Force albedo shading with full ambient light for validation."""
        self.set_shading_and_ambient(runner, self.shading_test,
                                     self.ambident_ratio_test)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .camera import DreamFusionCamera | ||
from .dreamfusion import DreamFusion | ||
from .renderer import DreamFusionRenderer | ||
from .stable_diffusion_wrapper import StableDiffusionWrapper | ||
|
||
# Public API of this sub-package (same names, one per line).
__all__ = [
    'DreamFusion',
    'DreamFusionRenderer',
    'DreamFusionCamera',
    'StableDiffusionWrapper',
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch | ||
from torch.autograd import Function | ||
from torch.cuda.amp import custom_bwd, custom_fwd | ||
|
||
|
||
class _trunc_exp(Function):
    """Exponential with a numerically truncated gradient.

    The forward pass is a plain ``exp``; the backward pass clamps the saved
    input to at most 15 before exponentiating, so the gradient cannot
    overflow for large activations (the forward output is left unclamped).
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, x):
        # Keep the raw input for the backward pass.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (saved_x, ) = ctx.saved_tensors
        # Clamp before exponentiating so the gradient stays finite.
        return grad_out * saved_x.clamp(max=15).exp()


trunc_exp = _trunc_exp.apply
Oops, something went wrong.