update setting activation scale for diffusers
eaidova committed Jan 15, 2025
1 parent fe55db5 commit 99586b1
Showing 4 changed files with 96 additions and 69 deletions.
59 changes: 21 additions & 38 deletions optimum/exporters/openvino/convert.py
@@ -101,12 +101,9 @@ def _set_runtime_options(
):
for model_name in models_and_export_configs.keys():
_, sub_export_config = models_and_export_configs[model_name]
-        sub_export_config.runtime_options = {}
-        if (
-            "diffusers" in library_name
-            or "text-generation" in task
-            or ("image-text-to-text" in task and model_name == "language_model")
-        ):
+        if not hasattr(sub_export_config, "runtime_options"):
+            sub_export_config.runtime_options = {}
+        if "text-generation" in task or ("image-text-to-text" in task and model_name == "language_model"):
sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
if not quantized_model and (
"text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
@@ -999,41 +996,22 @@ def _get_submodels_and_export_configs(
def get_diffusion_models_for_export_ext(
pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
):
-    if is_diffusers_version(">=", "0.29.0"):
-        from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
-
-        sd3_pipes = [StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline]
-        if is_diffusers_version(">=", "0.30.0"):
-            from diffusers import StableDiffusion3InpaintPipeline
-
-            sd3_pipes.append(StableDiffusion3InpaintPipeline)
-
-        is_sd3 = isinstance(pipeline, tuple(sd3_pipes))
-    else:
-        is_sd3 = False
-
-    if is_diffusers_version(">=", "0.30.0"):
-        from diffusers import FluxPipeline
+    is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
+    is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
+    is_flux = pipeline.__class__.__name__.startswith("Flux")
+    is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3

-        flux_pipes = [FluxPipeline]
-
-        if is_diffusers_version(">=", "0.31.0"):
-            from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline
-
-            flux_pipes.extend([FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline])
-
-        if is_diffusers_version(">=", "0.32.0"):
-            from diffusers import FluxFillPipeline
-
-            flux_pipes.append(FluxFillPipeline)
+    if not is_sd3 and not is_flux:
+        models_for_export = get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
+        if is_sdxl and pipeline.vae.config.force_upcast:
+            models_for_export["vae_encoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}
+            models_for_export["vae_decoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}

-        is_flux = isinstance(pipeline, tuple(flux_pipes))
-    else:
-        is_flux = False
+        if is_sd and pipeline.scheduler.config.prediction_type == "v_prediction":
+            models_for_export["vae_encoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
+            models_for_export["vae_decoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}

-    if not is_sd3 and not is_flux:
-        return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
-    if is_sd3:
+    elif is_sd3:
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
else:
models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
@@ -1135,6 +1113,7 @@ def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
int_dtype=int_dtype,
float_dtype=float_dtype,
)
+        export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["text_encoder_3"] = (text_encoder_3, export_config)

return models_for_export
@@ -1172,6 +1151,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
transformer_export_config = export_config_constructor(
pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
)
+    transformer_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["transformer"] = (transformer, transformer_export_config)

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
@@ -1187,6 +1167,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
vae_encoder_export_config = vae_config_constructor(
vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
+    vae_encoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
@@ -1202,6 +1183,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
vae_decoder_export_config = vae_config_constructor(
vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
+    vae_decoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)

text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
@@ -1218,6 +1200,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
int_dtype=int_dtype,
float_dtype=float_dtype,
)
+        export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)

return models_for_export
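
Note: the runtime_options attached to the export configs above are presumably what surface as rt_info entries in the serialized OpenVINO models, which is what the updated tests below assert via has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"]). A minimal sketch of writing such a value with the standard openvino Python API; the helper name write_runtime_options is illustrative and not part of this commit:

    import openvino as ov

    def write_runtime_options(ov_model: ov.Model, runtime_options: dict) -> ov.Model:
        # Store each option under rt_info/runtime_options/<KEY>, e.g. ACTIVATIONS_SCALE_FACTOR,
        # so that Model.has_rt_info(["runtime_options", key]) finds it after reload.
        for key, value in runtime_options.items():
            ov_model.set_rt_info(value, ["runtime_options", key])
        return ov_model

    # write_runtime_options(model, {"ACTIVATIONS_SCALE_FACTOR": "8.0"})
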
28 changes: 23 additions & 5 deletions optimum/intel/openvino/modeling_diffusion.py
@@ -63,7 +63,7 @@
)

from ...exporters.openvino import main_export
-from ..utils.import_utils import is_diffusers_version
+from ..utils.import_utils import is_diffusers_version, is_openvino_version
from .configuration import OVConfig, OVQuantizationMethod, OVWeightQuantizationConfig
from .loaders import OVTextualInversionLoaderMixin
from .modeling_base import OVBaseModel
@@ -75,6 +75,7 @@
_print_compiled_model_properties,
model_has_dynamic_inputs,
np_to_pt_generators,
+    check_scale_available,
)


@@ -484,8 +485,15 @@ def _from_pretrained(
ov_config = kwargs.get("ov_config", {})
device = kwargs.get("device", "CPU")
vae_ov_conifg = {**ov_config}
-        if "GPU" in device.upper() and "INFERENCE_PRECISION_HINT" not in vae_ov_conifg:
-            vae_ov_conifg["INFERENCE_PRECISION_HINT"] = "f32"
+        if (
+            "GPU" in device.upper()
+            and "INFERENCE_PRECISION_HINT" not in vae_ov_conifg
+            and is_openvino_version("<=", "2025.0")
+        ):
+            vae_model_path = models["vae_decoder"]
+            required_upcast = check_scale_available(vae_model_path)
+            if required_upcast:
+                vae_ov_conifg["INFERENCE_PRECISION_HINT"] = "f32"
for name, path in models.items():
if name in kwargs:
models[name] = kwargs.pop(name)
@@ -1202,7 +1210,12 @@ def forward(
return ModelOutput(**model_outputs)

def _compile(self):
-        if "GPU" in self._device and "INFERENCE_PRECISION_HINT" not in self.ov_config:
+        if (
+            "GPU" in self._device
+            and "INFERENCE_PRECISION_HINT" not in self.ov_config
+            and is_openvino_version("<", "2025.0")
+            and check_scale_available(self.model)
+        ):
self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"})
super()._compile()

@@ -1241,7 +1254,12 @@ def forward(
return ModelOutput(**model_outputs)

def _compile(self):
-        if "GPU" in self._device and "INFERENCE_PRECISION_HINT" not in self.ov_config:
+        if (
+            "GPU" in self._device
+            and "INFERENCE_PRECISION_HINT" not in self.ov_config
+            and is_openvino_version("<", "2025.0")
+            and check_scale_available(self.model)
+        ):
self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"})
super()._compile()

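
The new guards in _from_pretrained and _compile above share one gating idea: apply the f32 inference precision hint to the VAE on GPU only when no explicit hint was provided, the OpenVINO version check against 2025.0 passes, and the exported model actually carries an ACTIVATIONS_SCALE_FACTOR entry in its rt_info. A condensed restatement of that predicate, illustrative only (the helper name is hypothetical):

    def needs_f32_precision_hint(device: str, ov_config: dict, old_openvino: bool, scale_available: bool) -> bool:
        # GPU target, no user-provided precision hint, an older OpenVINO release,
        # and a model exported with an activations scale factor in its rt_info.
        return (
            "GPU" in device.upper()
            and "INFERENCE_PRECISION_HINT" not in ov_config
            and old_openvino
            and scale_available
        )
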
18 changes: 18 additions & 0 deletions optimum/intel/openvino/utils.py
@@ -565,3 +565,21 @@ def onexc(func, path, exc):
def cleanup(self):
if self._finalizer.detach() or os.path.exists(self.name):
self._rmtree(self.name, ignore_errors=self._ignore_cleanup_errors)


+def check_scale_available(model: Union[Model, str, Path]):
+    if isinstance(model, Model):
+        return model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
+    if not Path(model).exists():
+        return False
+    import xml.etree.ElementTree as ET
+
+    tree = ET.parse(model)
+    root = tree.getroot()
+    rt_info = root.find("rt_info")
+    if rt_info is None:
+        return False
+    runtime_options = rt_info.find("runtime_options")
+    if runtime_options is None:
+        return False
+    return runtime_options.find("ACTIVATIONS_SCALE_FACTOR") is not None
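
A quick usage sketch for the new check_scale_available helper, which accepts either a loaded openvino Model or a path to a serialized .xml file; the path below is a placeholder:

    import openvino as ov
    from optimum.intel.openvino.utils import check_scale_available

    # From an already loaded model: delegates to Model.has_rt_info.
    core = ov.Core()
    model = core.read_model("exported/vae_decoder/openvino_model.xml")
    print(check_scale_available(model))

    # Straight from the XML file: inspects rt_info/runtime_options without loading the model.
    print(check_scale_available("exported/vae_decoder/openvino_model.xml"))
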
60 changes: 34 additions & 26 deletions tests/openvino/test_export.py
@@ -75,6 +75,13 @@ class ExportModelTest(unittest.TestCase):
"llava": OVModelForVisualCausalLM,
}

+    EXPECTED_DIFFUSERS_SCALE_FACTORS = {
+        "stable-diffusion-xl": {"vae_encoder": "128.0", "vae_decoder": "128.0"},
+        "stable-diffusion-3": {"text_encoder_3": "8.0"},
+        "flux": {"text_encoder_2": "8.0", "transformer": "8.0", "vae_encoder": "8.0", "vae_decoder": "8.0"},
+        "stable-diffusion-xl-refiner": {"vae_encoder": "128.0", "vae_decoder": "128.0"},
+    }

if is_transformers_version(">=", "4.45"):
SUPPORTED_ARCHITECTURES.update({"stable-diffusion-3": OVStableDiffusion3Pipeline, "flux": OVFluxPipeline})

@@ -143,32 +150,33 @@ def _openvino_export(
)

if library_name == "diffusers":
-            self.assertTrue(
-                ov_model.vae_encoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
-            )
-            self.assertTrue(
-                ov_model.vae_decoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
-            )
-            if hasattr(ov_model, "text_encoder") and ov_model.text_encoder:
-                self.assertTrue(
-                    ov_model.text_encoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
-                )
-            if hasattr(ov_model, "text_encoder_2") and ov_model.text_encoder_2:
-                self.assertTrue(
-                    ov_model.text_encoder_2.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
-                )
-            if hasattr(ov_model, "text_encoder_3") and ov_model.text_encoder_3:
-                self.assertTrue(
-                    ov_model.text_encoder_3.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
-                )
-            if hasattr(ov_model, "unet") and ov_model.unet:
-                self.assertTrue(
-                    ov_model.unet.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
-                )
-            if hasattr(ov_model, "transformer") and ov_model.transformer:
-                self.assertTrue(
-                    ov_model.transformer.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
-                )
+            expected_scale_factors = self.EXPECTED_DIFFUSERS_SCALE_FACTORS.get(model_type, {})
+            components = [
+                "unet",
+                "transformer",
+                "text_encoder",
+                "text_encoder_2",
+                "text_encoder_3",
+                "vae_encoder",
+                "vae_decoder",
+            ]
+            for component in components:
+                component_model = getattr(ov_model, component, None)
+                if component_model is None:
+                    continue
+                component_scale = expected_scale_factors.get(component)
+                if component_scale is not None:
+                    self.assertTrue(
+                        component_model.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
+                    )
+                    self.assertEqual(
+                        component_model.model.get_rt_info()["runtime_options"]["ACTIVATIONS_SCALE_FACTOR"],
+                        component_scale,
+                    )
+                else:
+                    self.assertFalse(
+                        component_model.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
+                    )

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_export(self, model_type: str):
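
The table-driven assertions above pin the expected ACTIVATIONS_SCALE_FACTOR per pipeline component. Outside the test suite, the same rt_info can be inspected on an exported pipeline roughly as follows; the model id is a placeholder, and "128.0" is the value expected for SDXL exports whose VAE config enables force_upcast:

    from optimum.intel import OVStableDiffusionXLPipeline

    pipe = OVStableDiffusionXLPipeline.from_pretrained("my-org/my-sdxl-checkpoint", export=True)
    rt_info = pipe.vae_decoder.model.get_rt_info()
    # Expected to be "128.0" here, matching EXPECTED_DIFFUSERS_SCALE_FACTORS above.
    print(rt_info["runtime_options"]["ACTIVATIONS_SCALE_FACTOR"])
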
