Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added "Annealed importance guidance" and DRaFT+ docs #270

Merged
merged 10 commits into from
Sep 6, 2024
Merged
104 changes: 98 additions & 6 deletions docs/user-guide/draftp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ You can then run the following snipet to convert it to a ``.tar`` file:
Reward Model
############

Currently, we only have support for `Pickscore <https://arxiv.org/pdf/2305.01569.pdf>`__ reward model. Since Pickscore is a CLIP-based model,
Currently, we only have support for `Pickscore-style <https://arxiv.org/pdf/2305.01569.pdf>`__ reward models (PickScore/HPSv2). Since Pickscore is a CLIP-based model,
you can use the `conversion script <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py>`__ from NeMo to convert it from huggingface to NeMo.

DRaFT+ Training
Expand All @@ -81,8 +81,9 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu
UNET_CKPT="/path/to/unet_weights.ckpt"
VAE_CKPT="/path/to/vae_weights.bin"
RM_CKPT="/path/to/reward_model.nemo"
DRAFTP_SCRIPT="train_sd_draftp.py" # or train_sdxl_draftp.py

torchrun --nproc_per_node=2 ${GPFS}/examples/mm/stable_diffusion/train_sd_draftp.py \
torchrun --nproc_per_node=2 ${GPFS}/examples/mm/stable_diffusion/${DRAFTP_SCRIPT} \
trainer.num_nodes=1 \
trainer.devices=2 \
model.micro_batch_size=1 \
Expand All @@ -92,7 +93,7 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu
model.unet_config.from_pretrained=${UNET_CKPT} \
model.first_stage_config.from_pretrained=${VAE_CKPT} \
rm.model.restore_from_path=${RM_CKPT} \
model.data.trian.webdataset.local_root_path=${TRAIN_DATA_PATH} \
model.data.train.webdataset.local_root_path=${TRAIN_DATA_PATH} \
exp_manager.create_wandb_logger=False \
exp_manager.explicit_log_dir=/results

Expand Down Expand Up @@ -135,14 +136,16 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu

MOUNTS="--container-mounts=MOUNTS" # mounts

DRAFTP_SCRIPT="train_sd_draftp.py" # or train_sdxl_draftp.py

read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& echo "---------------" \
&& echo "Starting training" \
&& cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& python -u ${GPFS}/examples/nlp/gpt/train_reward_model.py \
&& python -u ${GPFS}/examples/mm/stable_diffusion/${DRAFTP_SCRIPT} \
trainer.num_nodes=1 \
trainer.devices=8 \
model.micro_batch_size=2 \
Expand All @@ -164,13 +167,102 @@ To launch reward model training, you must have checkpoints for `UNet <https://hu


.. note::
For more info on DRaFT+ hyperparameters please see the model config file:
For more info on DRaFT+ hyperparameters please see the model config files (for SD and SDXL respectively):

``NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sd.yaml``
``NeMo-Aligner/examples/mm/stable_diffusion/conf/draftp_sdxl.yaml``

DRaFT+ Results
%%%%%%%%%%%%%%

Once you have completed fine-tuning Stable Diffusion with DRaFT+, you can run inference on your saved model using the `sd_infer.py <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py>`__
and `sd_lora_infer.py <https://github.com/NVIDIA/NeMo/blob/main/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py>`__ scripts from the NeMo codebase. The generated images with the fine-tuned model should have
better prompt alignment and aesthetic quality.
better prompt alignment and aesthetic quality.

User controllable finetuning with Annealed Importance Guidance (AIG)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

AIG provides the inference-time flexibility to interpolate between the base Stable Diffusion model (with low rewards and high diversity) and DRaFT-finetuned model (with high rewards and low diversity) to obtain images with high rewards and high diversity. AIG inference is easily done by specifying comma-separated `weight_type` strategies to interpolate between the base and finetuned model.

.. tab-set::
.. tab-item:: AIG on Stable Diffusion XL
:sync: key2

Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case).
Weight type of the form `power_<float>` interpolates using an exponential decay specified in the AIG paper.

To run AIG inference on the terminal directly:

.. code-block:: bash

NUMNODES=1
LR=${LR:=0.00025}
INF_STEPS=${INF_STEPS:=25}
KL_COEF=${KL_COEF:=0.1}
ETA=${ETA:=0.0}
DATASET=${DATASET:="pickapic50k.tar"}
MICRO_BS=${MICRO_BS:=1}
GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4}
PEFT=${PEFT:="sdlora"}
NUM_DEVICES=${NUM_DEVICES:=8}
GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES))
LOG_WANDB=${LOG_WANDB:="False"}

echo "additional kwargs: ${ADDITIONAL_KWARGS}"

WANDB_NAME=SDXL_Draft_annealing
WEBDATASET_PATH=/path/to/${DATASET}

CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf"
CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"}
UNET_CKPT="/path/to/unet.ckpt"
VAE_CKPT="/path/to/vae.ckpt"
RM_CKPT="/path/to/reward_model.nemo"
PROMPT=${PROMPT:="Bananas growing on an apple tree"}
DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir

if [ ! -z "${ACT_CKPT}" ]; then
ACT_CKPT="model.activation_checkpointing=$ACT_CKPT "
echo $ACT_CKPT
fi

EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"}
export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1
set -x
CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \
--config-path=${CONFIG_PATH} \
--config-name=${CONFIG_NAME} \
model.optim.lr=${LR} \
model.optim.weight_decay=0.0005 \
model.optim.sched.warmup_steps=0 \
model.sampling.base.steps=${INF_STEPS} \
model.kl_coeff=${KL_COEF} \
model.truncation_steps=1 \
trainer.draftp_sd.max_epochs=5 \
trainer.draftp_sd.max_steps=10000 \
trainer.draftp_sd.save_interval=200 \
trainer.draftp_sd.val_check_interval=20 \
trainer.draftp_sd.gradient_clip_val=10.0 \
model.micro_batch_size=${MICRO_BS} \
model.global_batch_size=${GLOBAL_BATCH_SIZE} \
model.peft.peft_scheme=${PEFT} \
model.data.webdataset.local_root_path=$WEBDATASET_PATH \
rm.model.restore_from_path=${RM_CKPT} \
trainer.devices=${NUM_DEVICES} \
trainer.num_nodes=${NUMNODES} \
rm.trainer.devices=${NUM_DEVICES} \
rm.trainer.num_nodes=${NUMNODES} \
+prompt="${PROMPT}" \
exp_manager.create_wandb_logger=${LOG_WANDB} \
model.first_stage_config.from_pretrained=${VAE_CKPT} \
model.first_stage_config.from_NeMo=True \
model.unet_config.from_pretrained=${UNET_CKPT} \
model.unet_config.from_NeMo=True \
$ACT_CKPT \
exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \
exp_manager.resume_if_exists=True \
exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \
exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0'



Loading
Loading