Add Torch autocast and full bf16 to GaudiStableDiffusionPipeline (#278)
regisss authored Jun 23, 2023
1 parent 466f0af commit ae4b61f
Showing 6 changed files with 332 additions and 214 deletions.
21 changes: 21 additions & 0 deletions docs/source/tutorials/stable_diffusion.mdx
@@ -103,3 +103,24 @@ There are two different checkpoints for Stable Diffusion 2:
- use [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) for generating 512x512 images

</Tip>


## Tips

To accelerate your Stable Diffusion pipeline, you can run it in full *bfloat16* precision, which also saves memory.
Simply pass `torch_dtype=torch.bfloat16` to `from_pretrained` when instantiating your pipeline:

```py
import torch

from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline

model_name = "runwayml/stable-diffusion-v1-5"
scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")

pipeline = GaudiStableDiffusionPipeline.from_pretrained(
    model_name,
    scheduler=scheduler,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
    torch_dtype=torch.bfloat16,
)
```
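Once instantiated, the pipeline is called like any other Diffusers pipeline. Here is a minimal usage sketch, assuming the standard Diffusers call signature (`prompt`, `num_images_per_prompt`) and that generated images are exposed through `outputs.images`, as in the example script further down:

```py
outputs = pipeline(
    prompt="An astronaut riding a horse on Mars",
    num_images_per_prompt=4,
)
# Save the first generated image to disk
outputs.images[0].save("image.png")
```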
6 changes: 4 additions & 2 deletions examples/stable-diffusion/README.md
@@ -35,7 +35,8 @@ python text_to_image_generation.py \
--image_save_dir /tmp/stable_diffusion_images \
--use_habana \
--use_hpu_graphs \
- --gaudi_config Habana/stable-diffusion
+ --gaudi_config Habana/stable-diffusion \
+ --bf16
```

> HPU graphs are recommended when generating images in batches to get the fastest possible generation.
@@ -55,7 +56,8 @@ python text_to_image_generation.py \
--image_save_dir /tmp/stable_diffusion_images \
--use_habana \
--use_hpu_graphs \
- --gaudi_config Habana/stable-diffusion
+ --gaudi_config Habana/stable-diffusion \
+ --bf16
```

> HPU graphs are recommended when generating images in batches to get the fastest possible generation.
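For batched generation from Python instead of the command line, here is a short sketch, assuming a `pipeline` built as in the tutorial snippet above and that the call accepts `num_images_per_prompt` and `batch_size` keyword arguments (neither is shown in this diff):

```py
# Generate 8 images in batches of 4; with HPU graphs enabled, batches of
# identical shape replay the recorded graph instead of being recompiled.
outputs = pipeline(
    prompt="A portrait of a cat, oil painting",
    num_images_per_prompt=8,
    batch_size=4,
)
```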
24 changes: 19 additions & 5 deletions examples/stable-diffusion/text_to_image_generation.py
@@ -18,6 +18,8 @@
import sys
from pathlib import Path

import torch

from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
from optimum.habana.utils import set_seed

@@ -95,7 +97,12 @@ def main():
default=None,
help="The directory where the generation pipeline will be saved.",
)
parser.add_argument("--image_save_dir", type=str, default=None, help="The directory where images will be saved.")
parser.add_argument(
"--image_save_dir",
type=str,
default="./stable-diffusion-generated-images",
help="The directory where images will be saved.",
)

parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")

@@ -113,6 +120,7 @@
" Precision."
),
)
parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")

args = parser.parse_args()

@@ -126,12 +134,17 @@

# Initialize the scheduler and the generation pipeline
scheduler = GaudiDDIMScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler")
+ kwargs = {
+     "scheduler": scheduler,
+     "use_habana": args.use_habana,
+     "use_hpu_graphs": args.use_hpu_graphs,
+     "gaudi_config": args.gaudi_config_name,
+ }
+ if args.bf16:
+     kwargs["torch_dtype"] = torch.bfloat16
pipeline = GaudiStableDiffusionPipeline.from_pretrained(
args.model_name_or_path,
- scheduler=scheduler,
- use_habana=args.use_habana,
- use_hpu_graphs=args.use_hpu_graphs,
- gaudi_config=args.gaudi_config_name,
+ **kwargs,
)

# Set seed before running the model
@@ -160,6 +173,7 @@
if args.output_type == "pil":
image_save_dir = Path(args.image_save_dir)
image_save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving images in {image_save_dir.resolve()}...")
for i, image in enumerate(outputs.images):
image.save(image_save_dir / f"image_{i+1}.png")
else:
72 changes: 49 additions & 23 deletions optimum/habana/diffusers/pipelines/pipeline_utils.py
@@ -73,13 +73,17 @@ class GaudiDiffusionPipeline(DiffusionPipeline):
gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`):
Gaudi configuration to use. Can be a string to download it from the Hub,
or a previously initialized config can be passed.
bf16_full_eval (bool, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit.
This is faster and saves memory compared to fp32/mixed precision, but can degrade the quality of generated images.
"""

def __init__(
self,
use_habana: bool = False,
use_hpu_graphs: bool = False,
gaudi_config: Union[str, GaudiConfig] = None,
bf16_full_eval: bool = False,
):
super().__init__()

@@ -104,6 +108,47 @@ def __init__(
f"`gaudi_config` must be a string or a GaudiConfig object but is {type(gaudi_config)}."
)

if self.gaudi_config.use_habana_mixed_precision or self.gaudi_config.use_torch_autocast:
if bf16_full_eval:
logger.warning(
"`use_habana_mixed_precision` or `use_torch_autocast` is True in the given Gaudi configuration but "
"`torch_dtype=torch.blfloat16` was given. Disabling mixed precision and continuing in bf16 only."
)
elif self.gaudi_config.use_torch_autocast:
# Open temporary files to write mixed-precision ops
with tempfile.NamedTemporaryFile() as hmp_bf16_file:
with tempfile.NamedTemporaryFile() as hmp_fp32_file:
self.gaudi_config.write_bf16_fp32_ops_to_text_files(
hmp_bf16_file.name,
hmp_fp32_file.name,
)
os.environ["LOWER_LIST"] = str(hmp_bf16_file)
os.environ["FP32_LIST"] = str(hmp_fp32_file)

import habana_frameworks.torch.core # noqa

elif self.gaudi_config.use_habana_mixed_precision:
try:
from habana_frameworks.torch.hpex import hmp
except ImportError as error:
error.msg = f"Could not import habana_frameworks.torch.hpex. {error.msg}."
raise error

# Open temporary files to write mixed-precision ops
with tempfile.NamedTemporaryFile() as hmp_bf16_file:
with tempfile.NamedTemporaryFile() as hmp_fp32_file:
# hmp.convert needs ops to be written in text files
self.gaudi_config.write_bf16_fp32_ops_to_text_files(
hmp_bf16_file.name,
hmp_fp32_file.name,
)
hmp.convert(
opt_level=self.gaudi_config.hmp_opt_level,
bf16_file_path=hmp_bf16_file.name,
fp32_file_path=hmp_fp32_file.name,
isVerbose=self.gaudi_config.hmp_is_verbose,
)

if self.use_hpu_graphs:
try:
import habana_frameworks.torch as ht
Expand All @@ -120,29 +165,6 @@ def __init__(
error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
raise error
self.htcore = htcore

- if self.gaudi_config.use_habana_mixed_precision:
- try:
- from habana_frameworks.torch.hpex import hmp
- except ImportError as error:
- error.msg = f"Could not import habana_frameworks.torch.hpex. {error.msg}."
- raise error
- self.hmp = hmp
-
- # Open temporary files to write mixed-precision ops
- with tempfile.NamedTemporaryFile() as hmp_bf16_file:
- with tempfile.NamedTemporaryFile() as hmp_fp32_file:
- # hmp.convert needs ops to be written in text files
- self.gaudi_config.write_bf16_fp32_ops_to_text_files(
- hmp_bf16_file.name,
- hmp_fp32_file.name,
- )
- self.hmp.convert(
- opt_level=self.gaudi_config.hmp_opt_level,
- bf16_file_path=hmp_bf16_file.name,
- fp32_file_path=hmp_fp32_file.name,
- isVerbose=self.gaudi_config.hmp_is_verbose,
- )
else:
if use_hpu_graphs:
raise ValueError(
@@ -306,6 +328,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
diffusers.pipelines.pipeline_utils.LOADABLE_CLASSES = GAUDI_LOADABLE_CLASSES
diffusers.pipelines.pipeline_utils.ALL_IMPORTABLE_CLASSES = GAUDI_ALL_IMPORTABLE_CLASSES

# Define a new kwarg here so that __init__ knows whether to run in full bf16 or keep mixed precision
bf16_full_eval = kwargs.get("torch_dtype", None) == torch.bfloat16
kwargs["bf16_full_eval"] = bf16_full_eval

return super().from_pretrained(
pretrained_model_name_or_path,
**kwargs,
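Taken together, these changes give the pipeline two reduced-precision modes: mixed precision driven by the Gaudi configuration, and full bf16 triggered by `torch_dtype=torch.bfloat16`. Here is a sketch contrasting the two from the user side; passing `use_torch_autocast=True` to the `GaudiConfig` constructor is an assumption, inferred from the attribute read in `__init__` above:

```py
import torch

from optimum.habana import GaudiConfig
from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline

model_name = "runwayml/stable-diffusion-v1-5"
scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")

# Mixed precision via Torch autocast: the Gaudi config's bf16/fp32 op lists
# are written to text files and exposed through the LOWER_LIST and FP32_LIST
# environment variables before habana_frameworks.torch.core is imported.
pipe_autocast = GaudiStableDiffusionPipeline.from_pretrained(
    model_name,
    scheduler=scheduler,
    use_habana=True,
    gaudi_config=GaudiConfig(use_torch_autocast=True),  # assumed constructor kwarg
)

# Full bf16: torch_dtype is detected in from_pretrained, forwarded to __init__
# as bf16_full_eval, and any mixed-precision setting is disabled with a warning.
pipe_bf16 = GaudiStableDiffusionPipeline.from_pretrained(
    model_name,
    scheduler=scheduler,
    use_habana=True,
    gaudi_config="Habana/stable-diffusion",
    torch_dtype=torch.bfloat16,
)
```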