diff --git a/docs/user-guide/draftp.rst b/docs/user-guide/draftp.rst
index 3c67b8553..81b5a71c6 100644
--- a/docs/user-guide/draftp.rst
+++ b/docs/user-guide/draftp.rst
@@ -173,4 +173,159 @@ DRaFT+ Results
 Once you have completed fine-tuning Stable Diffusion with DRaFT+, you can run inference on your saved model using the `sd_infer.py `__ and
 `sd_lora_infer.py `__ scripts from the NeMo codebase. The generated images with the fine-tuned model should have
-better prompt alignment and aesthetic quality.
\ No newline at end of file
+better prompt alignment and aesthetic quality.
+
+User-Controllable Fine-Tuning with Annealed Importance Guidance (AIG)
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+AIG provides inference-time flexibility to interpolate between the base Stable Diffusion model (low reward, high diversity) and the DRaFT+-finetuned model (high reward, low diversity), so that the generated images have both high reward and high diversity. To run AIG inference, specify a comma-separated list of ``weight_type`` strategies that control how the base and finetuned models are interpolated.
+
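+At each denoising step, AIG blends the predictions of the base and DRaFT+-finetuned models with a step-dependent weight that anneals from the base model toward the finetuned one. The exact blending is implemented by the ``annealed_guidance`` method of the DRaFT+ model; the snippet below is only a conceptual sketch, in which the ``eps_base``/``eps_draft`` tensors and the linear mixing are illustrative assumptions rather than the literal implementation.
+
+.. code-block:: python
+
+    import torch
+
+    def interpolation_weight(i: int, total: int, power: float = 2.0) -> float:
+        # "power_<p>"-style schedule: the weight grows from 0 (pure base model)
+        # to 1 (pure DRaFT+ model) as denoising progresses.
+        return (i / total) ** power
+
+    total_steps = 25
+    eps_base = torch.randn(1, 4, 64, 64)   # stand-in for the base model prediction
+    eps_draft = torch.randn(1, 4, 64, 64)  # stand-in for the finetuned model prediction
+
+    for i in range(total_steps):
+        w = interpolation_weight(i, total_steps)
+        # annealed prediction used at this denoising step
+        eps = (1.0 - w) * eps_base + w * eps_draft
+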
-z "${ACT_CKPT}" ]; then + ACT_CKPT="model.activation_checkpointing=$ACT_CKPT " + echo $ACT_CKPT + fi + + EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"} + export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1 + set -x + CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ + --config-path=${CONFIG_PATH} \ + --config-name=${CONFIG_NAME} \ + model.optim.lr=${LR} \ + model.optim.weight_decay=0.0005 \ + model.optim.sched.warmup_steps=0 \ + model.sampling.base.steps=${INF_STEPS} \ + model.kl_coeff=${KL_COEF} \ + model.truncation_steps=1 \ + trainer.draftp_sd.max_epochs=5 \ + trainer.draftp_sd.max_steps=10000 \ + trainer.draftp_sd.save_interval=200 \ + trainer.draftp_sd.val_check_interval=20 \ + trainer.draftp_sd.gradient_clip_val=10.0 \ + model.micro_batch_size=${MICRO_BS} \ + model.global_batch_size=${GLOBAL_BATCH_SIZE} \ + model.peft.peft_scheme=${PEFT} \ + model.data.webdataset.local_root_path=$WEBDATASET_PATH \ + rm.model.restore_from_path=${RM_CKPT} \ + trainer.devices=${NUM_DEVICES} \ + trainer.num_nodes=${NUMNODES} \ + rm.trainer.devices=${NUM_DEVICES} \ + rm.trainer.num_nodes=${NUMNODES} \ + +prompt="${PROMPT}" \ + exp_manager.create_wandb_logger=${LOG_WANDB} \ + model.first_stage_config.from_pretrained=${VAE_CKPT} \ + model.first_stage_config.from_NeMo=True \ + model.unet_config.from_pretrained=${UNET_CKPT} \ + model.unet_config.from_NeMo=True \ + $ACT_CKPT \ + exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ + exp_manager.resume_if_exists=True \ + exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0' + + .. tab-item:: AIG on Stable Diffusion v1.1 - v1.5 + :sync: key + + Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case). + Weight type of the form `power_` interpolates using an exponential decay specified in the AIG paper. + + To run AIG inference on the terminal directly: + + .. 
+
+    .. tab-item:: AIG on Stable Diffusion v1.1 - v1.5
+        :sync: key
+
+        A weight type of ``base`` uses only the base model and ``draft`` uses only the finetuned model (no interpolation is done in either case).
+        A weight type of the form ``power_<p>`` (for example, ``power_2.0``) interpolates between the two using the exponential decay schedule described in the AIG paper.
+
+        To run AIG inference directly from the terminal:
+
+        .. code-block:: bash
+
+            LR=${LR:=0.00025}
+            INF_STEPS=${INF_STEPS:=25}
+            KL_COEF=${KL_COEF:=0.1}
+            ETA=${ETA:=0.0}
+            DATASET=${DATASET:="pickapic50k.tar"}
+            MICRO_BS=${MICRO_BS:=2}
+            GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4}
+            PEFT=${PEFT:="sdlora"}
+            NUM_DEVICES=8
+            GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION))
+
+            WANDB_NAME=SD_DRaFT_annealing
+            WEBDATASET_PATH=/path/to/${DATASET}
+
+            CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf"
+            CONFIG_NAME="draftp_sd"
+            UNET_CKPT="/path/to/unet.ckpt"
+            VAE_CKPT="/path/to/vae.ckpt"
+            RM_CKPT="/path/to/rewardmodel.nemo"
+
+            # change this as an end-user
+            PROMPT=${PROMPT:-"Bananas growing on an apple tree"}
+
+            EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sd.py"}
+            set -x
+            DEVICE="0,1,2,3,4,5,6,7"
+            echo "Running DRaFT on ${DEVICE}"
+            export HYDRA_FULL_ERROR=1 \
+            && MASTER_PORT=15003 CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=${NUM_DEVICES} /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \
+                --config-path=${CONFIG_PATH} \
+                --config-name=${CONFIG_NAME} \
+                model.optim.lr=${LR} \
+                model.optim.weight_decay=0.005 \
+                model.optim.sched.warmup_steps=0 \
+                model.infer.inference_steps=${INF_STEPS} \
+                model.infer.eta=0.0 \
+                model.kl_coeff=${KL_COEF} \
+                model.truncation_steps=1 \
+                trainer.draftp_sd.max_epochs=1 \
+                trainer.draftp_sd.max_steps=4000 \
+                trainer.draftp_sd.save_interval=100 \
+                model.unet_config.from_pretrained=${UNET_CKPT} \
+                model.first_stage_config.from_pretrained=${VAE_CKPT} \
+                model.micro_batch_size=${MICRO_BS} \
+                model.global_batch_size=${GLOBAL_BATCH_SIZE} \
+                model.peft.peft_scheme=${PEFT} \
+                model.data.webdataset.local_root_path=$WEBDATASET_PATH \
+                rm.model.restore_from_path=${RM_CKPT} \
+                +prompt="${PROMPT}" \
+                trainer.draftp_sd.val_check_interval=20 \
+                trainer.draftp_sd.gradient_clip_val=10.0 \
+                trainer.devices=${NUM_DEVICES} \
+                rm.trainer.devices=${NUM_DEVICES} \
+                exp_manager.create_wandb_logger=True \
+                exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \
+                exp_manager.resume_if_exists=True \
+                exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \
+                exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0'
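+
+In both cases, each entry of the comma-separated ``weight_type`` override is evaluated independently. As implemented in ``anneal_sd.py`` below (``anneal_sdxl.py`` is expected to behave analogously), the script writes one output directory per strategy, named ``annealed_outputs_sd_<weight_type>/``, containing numbered image/prompt pairs per rank, and passing ``+prompt`` replaces the validation prompts with that single prompt.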
+
diff --git a/examples/mm/stable_diffusion/anneal_sd.py b/examples/mm/stable_diffusion/anneal_sd.py
new file mode 100644
index 000000000..86c9cb387
--- /dev/null
+++ b/examples/mm/stable_diffusion/anneal_sd.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from copy import deepcopy
+from functools import partial
+
+import numpy as np
+import torch
+import torch.distributed
+import torch.multiprocessing as mp
+from megatron.core import parallel_state
+from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name
+from megatron.core.utils import divide
+from omegaconf.omegaconf import OmegaConf, open_dict
+from packaging.version import Version
+from PIL import Image
+from torch import nn
+
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder
+from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+from nemo.utils.exp_manager import exp_manager
+from nemo_aligner.algorithms.supervised import SupervisedTrainer
+from nemo_aligner.data.mm import text_webdataset
+from nemo_aligner.data.nlp.builders import build_dataloader
+from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model
+from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel
+from nemo_aligner.utils.distributed import Timer
+from nemo_aligner.utils.train_script_utils import (
+    CustomLoggerWrapper,
+    add_custom_checkpoint_callback,
+    extract_optimizer_scheduler_from_ptl_model,
+    init_distributed,
+    init_peft,
+    init_using_ptl,
+    retrieve_custom_trainer_state_dict,
+    temp_pop_from_config,
+)
+
+mp.set_start_method("spawn", force=True)
+
+
+def resolve_and_create_trainer(cfg, pop_trainer_key):
+    """resolve the cfg, remove the key before constructing the PTL trainer
+        and then restore it after
+    """
+    OmegaConf.resolve(cfg)
+    with temp_pop_from_config(cfg.trainer, pop_trainer_key):
+        return MegatronStableDiffusionTrainerBuilder(cfg).create_trainer()
+
+
+@hydra_runner(config_path="conf", config_name="draftp_sd")
+def main(cfg) -> None:
+
+    logging.info("\n\n************** Experiment configuration ***********")
+    logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
+
+    # set cuda device for each process
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+
+    cfg.exp_manager.create_wandb_logger = False
+
+    if Version(torch.__version__) >= Version("1.12"):
+        torch.backends.cuda.matmul.allow_tf32 = True
+    cfg.model.data.train.dataset_path = [cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices)]
+    cfg.model.data.validation.dataset_path = [
+        cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices)
+    ]
+
+    trainer = resolve_and_create_trainer(cfg, "draftp_sd")
+    exp_manager(trainer, cfg.exp_manager)
+    logger = CustomLoggerWrapper(trainer.loggers)
+    # Instantiating the model here
+    ptl_model = MegatronSDDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device())
+    init_peft(ptl_model, cfg.model)
+
+    trainer_restore_path = trainer.ckpt_path
+
+    if trainer_restore_path is not None:
+        custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
+        consumed_samples = custom_trainer_state_dict["consumed_samples"]
+    else:
+        custom_trainer_state_dict = None
+        consumed_samples = 0
+
+    init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))
+
+    train_ds, validation_ds = text_webdataset.build_train_valid_datasets(
+        cfg.model.data, consumed_samples=consumed_samples
+    )
+    validation_ds = [d["captions"] for d in list(validation_ds)]
+
+    val_dataloader = build_dataloader(
+        cfg,
+        dataset=validation_ds,
+        consumed_samples=consumed_samples,
+        mbs=cfg.model.micro_batch_size,
+        gbs=cfg.model.global_batch_size,
+        load_gbs=True,
+    )
+
+    init_using_ptl(trainer, ptl_model, val_dataloader, validation_ds)
+
+    optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)
+
+    ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)
+
+    logger.log_hyperparams(OmegaConf.to_container(cfg))
+
+    reward_model = get_reward_model(cfg.rm, mbs=cfg.model.micro_batch_size, gbs=cfg.model.global_batch_size)
+    ptl_model.reward_model = reward_model
+
+    ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)
+    timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:12:00:00"))
+
+    draft_p_trainer = SupervisedTrainer(
+        cfg=cfg.trainer.draftp_sd,
+        model=ptl_model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        train_dataloader=val_dataloader,
+        val_dataloader=val_dataloader,
+        test_dataloader=[],
+        logger=logger,
+        ckpt_callback=ckpt_callback,
+        run_timer=timer,
+    )
+
+    if custom_trainer_state_dict is not None:
+        draft_p_trainer.load_state_dict(custom_trainer_state_dict)
+
+    # Run annealed guidance
+    if cfg.get("prompt") is not None:
+        logging.info(f"Override val dataset with custom prompt: {cfg.prompt}")
+        val_dataloader = [[cfg.prompt]]
+
+    wt_types = cfg.get("weight_type", None)
+    if wt_types is None:
+        wt_types = ["base", "draft", "linear", "power_2", "power_4", "step_0.6"]
+    else:
+        wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types
+    logging.info(f"Running on types: {wt_types}")
+
+    # run for all weight types
+    for wt_type in wt_types:
+        global_idx = 0
+        if wt_type is None or wt_type == "base":
+            # dummy function that assigns a value of 0 all the time
+            logging.info("using the base model")
+            wt_draft = lambda sigma, sigma_next, i, total: 0
+        else:
+            if wt_type == "linear":
+                wt_draft = lambda sigma, sigma_next, i, total: i * 1.0 / total
+            elif wt_type == "draft":
+                wt_draft = lambda sigma, sigma_next, i, total: 1
+            elif wt_type.startswith("power"):  # it is of the form power_{power}
+                pow = float(wt_type.split("_")[1])
+                wt_draft = lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** pow
+            elif wt_type.startswith("step"):  # use a step function (step_{p})
+                frac = float(wt_type.split("_")[1])
+                wt_draft = lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac)
+            else:
+                raise ValueError(f"invalid weighing type: {wt_type}")
+            logging.info(f"using weighing type for annealed outputs: {wt_type}.")
+
+        # initialize generator
+        gen = torch.Generator(device="cpu")
+        gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1)))
+        os.makedirs(f"./annealed_outputs_sd_{wt_type}/", exist_ok=True)
+
+        for batch in val_dataloader:
+            batch_size = len(batch)
+            with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
+                latents = torch.randn(
+                    [
+                        batch_size,
+                        ptl_model.in_channels,
+                        ptl_model.height // ptl_model.downsampling_factor,
+                        ptl_model.width // ptl_model.downsampling_factor,
+                    ],
+                    generator=gen,
+                ).to(torch.cuda.current_device())
+            images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft)
+            images = (
+                images.permute(0, 2, 3, 1).detach().float().cpu().numpy().astype(np.uint8)
+            )  # outputs are already scaled to [0, 255]
+            # save each image and its prompt with PIL
+            for i in range(images.shape[0]):
+                # offset only the file index by the images already written for this weight type
+                idx = i + global_idx
+                img_path = f"annealed_outputs_sd_{wt_type}/img_{idx:05d}_{local_rank:02d}.png"
+                prompt_path = f"annealed_outputs_sd_{wt_type}/prompt_{idx:05d}_{local_rank:02d}.txt"
+                Image.fromarray(images[i]).save(img_path)
+                with open(prompt_path, "w") as fi:
+                    fi.write(batch[i])
+            # increment global index
+            global_idx += batch_size
+    logging.info("Saved all images.")
+
+
+if __name__ == "__main__":
+    main()