hls4ml Optimization API [Part 1] #768

Merged
merged 25 commits on Jan 29, 2024
Changes from 12 commits
Commits (25)
8ba4206
Optimization API config files & model attribute builder
bo3z Apr 14, 2023
12fba05
Optimization sparsity schedulers
bo3z Apr 14, 2023
d95c956
Utils & regularizers for Keras optimization
bo3z Apr 14, 2023
de51797
Knapsack solver & unstructured pruning objective
bo3z Apr 14, 2023
a49a113
Keras optimization masking and weight removal logic
bo3z Apr 14, 2023
e655ab6
Remove unused channels w/ Keras Surgeon
bo3z Apr 14, 2023
399a98d
Hyperparameter tuning for pruning & weight sharing regularization
bo3z Apr 14, 2023
47392ba
Top-level Keras optimization function & GPU FLOPs optimization
bo3z Apr 14, 2023
a778e39
hls4ml objectives & top-level optimization function
bo3z Apr 14, 2023
7cd25a0
Add docs for hls4ml Optimization API
bo3z Apr 14, 2023
f792ea6
Full support for multi-objective Vivado optimisation
bo3z Jun 16, 2023
82779ff
part of pre-commit of Optimization API pt.1
bo3z Jun 16, 2023
aac3f1f
Fix missing packages; rename PyTests & pre-commit pt 2.
bo3z Sep 15, 2023
5b4488e
Merge branch 'main' into hls4ml-optimization-api-part-1
bo3z Sep 15, 2023
7c2d128
Fix failing tests & GitHub warnings
bo3z Sep 15, 2023
e044a12
Merge branch 'main' into hls4ml-optimization-api-part-1
bo3z Oct 9, 2023
4aff443
Merge branch 'master' into hls4ml-optimization-api-part-1
bo3z Dec 4, 2023
1156ff9
Merge branch 'fastmachinelearning:hls4ml-optimization-api-part-1' int…
bo3z Dec 4, 2023
c4a5a0f
Fix optimization failing PyTests
bo3z Dec 4, 2023
7a26a9a
Cleanup docstrings
vloncar Jan 25, 2024
ad47f41
Merge pull request #3 from vloncar/opt1
bo3z Jan 25, 2024
4675607
Fix docstring in ObjectiveEstimator
vloncar Jan 29, 2024
8503c86
Add optimization API paper to reference.rst
vloncar Jan 29, 2024
6eb391f
Rename optimize_keras_for_hls4ml
vloncar Jan 29, 2024
e66d7e7
Merge pull request #4 from vloncar/opt1
bo3z Jan 29, 2024
120 changes: 120 additions & 0 deletions docs/advanced/model_optimization.rst
@@ -0,0 +1,120 @@
========================
hls4ml Optimization API
========================

Pruning and weight sharing are effective techniques to reduce model footprint and computational requirements. The hls4ml Optimization API introduces hardware-aware pruning and weight sharing.
By defining custom objectives, the algorithm solves a Knapsack optimization problem aimed at maximizing model performance, while keeping the target resource(s) at a minimum. Out-of-the-box objectives include network sparsity, GPU FLOPs, Vivado DSPs, memory utilization, etc.

The code blocks below showcase three use cases of the hls4ml Optimization API: network sparsity (unstructured pruning), GPU FLOPs (structured pruning) and Vivado DSP utilization (pattern pruning). First, we start with unstructured pruning:

.. code-block:: Python
import numpy as np
from sklearn.metrics import accuracy_score
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.losses import CategoricalCrossentropy
from hls4ml.optimization.keras import optimize_model
from hls4ml.optimization.keras.utils import get_model_sparsity
from hls4ml.optimization.attributes import get_attributes_from_keras_model
from hls4ml.optimization.objectives import ParameterEstimator
from hls4ml.optimization.scheduler import PolynomialScheduler
# Define baseline model and load data
# X_train, y_train = ...
# X_val, y_val = ...
# X_test, y_test = ...
# baseline_model = ...
# Evaluate baseline model
y_baseline = baseline_model.predict(X_test)
acc_base = accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_baseline, axis=1))
sparsity, layers = get_model_sparsity(baseline_model)
print(f'Baseline Keras accuracy: {acc_base}')
print(f'Baseline Keras sparsity, overall: {sparsity}')
print(f'Baseline Keras sparsity, per-layer: {layers}')
# Define training parameters
# Epochs refers to the maximum number of epochs to fine-tune the model in one iteration of pruning, i.e. after imposing some sparsity
# If the model is pre-trained, a good rule of thumb is to use between 1/3 and 1/2 of the number of epochs used to train the baseline model
epochs = 10
batch_size = 128
metric = 'accuracy'
optimizer = Adam()
loss_fn = CategoricalCrossentropy(from_logits=True)

# Define the metric to monitor, as well as whether it is increasing or decreasing
# This distinction allows us to optimize both regression and classification models
# In regression, e.g. minimize validation MSE; in classification, e.g. maximize accuracy
metric, increasing = CategoricalAccuracy(), True
# Relative tolerance (rtol) is the fraction of the baseline metric the optimized model must retain;
# optimization stops once the metric falls below (or, for decreasing metrics, rises above) rtol * baseline
rtol = 0.975

# A scheduler defines how the sparsity is incremented at each step
# In this case, the final sparsity is 50% and it is reached at a polynomially decaying rate over the given number of steps
# If the final sparsity is unspecified, it is set to 100%
# The optimization algorithm stops either when (i) performance drops below the tolerated threshold (rtol) or (ii) the final sparsity is reached
scheduler = PolynomialScheduler(5, final_sparsity=0.5)
# Get model attributes
model_attributes = get_attributes_from_keras_model(baseline_model)

# Optimize model
# ParameterEstimator is the objective and, in this case, the objective is to minimize the total number of parameters
optimized_model = optimize_model(
baseline_model, model_attributes, ParameterEstimator, scheduler,
X_train, y_train, X_val, y_val, batch_size, epochs, optimizer, loss_fn, metric, increasing, rtol
)
# Evaluate optimized model
y_optimized = optimized_model.predict(X_test)
acc_optimized = accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_optimized, axis=1))
sparsity, layers = get_model_sparsity(optimized_model)
print(f'Optimized Keras accuracy: {acc_optimized}')
print(f'Optimized Keras sparsity, overall: {sparsity}')
print(f'Optimized Keras sparsity, per-layer: {layers}')

In a similar manner, it is possible to target GPU FLOPs or Vivado DSPs. However, in that case, sparsity is not equivalent to model sparsity.
Instead, it is the sparsity of the target resource. For example, starting with a network utilizing 512 DSPs and a final sparsity of 50%, the optimized network will use 256 DSPs.

To optimize GPU FLOPs, the code is similar to the above:

.. code-block:: Python
from hls4ml.optimization.objectives.gpu_objectives import GPUFLOPEstimator

# Optimize model
# Note the change from ParameterEstimator to GPUFLOPEstimator
optimized_model = optimize_model(
baseline_model, model_attributes, GPUFLOPEstimator, scheduler,
X_train, y_train, X_val, y_val, batch_size, epochs, optimizer, loss_fn, metric, increasing, rtol
)
# Evaluate optimized model
y_optimized = optimized_model.predict(X_test)
acc_optimized = accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_optimized, axis=1))
print(f'Optimized Keras accuracy: {acc_optimized}')
# Note the difference in total number of parameters
# Optimizing GPU FLOPs is equivalent to removing entire structures (filters, neurons) from the network
print(baseline_model.summary())
print(optimized_model.summary())

Finally, optimizing Vivado DSPs is possible, given an hls4ml config:

.. code-block:: Python
from hls4ml.utils.config import config_from_keras_model
from hls4ml.optimization.objectives.vivado_objectives import VivadoDSPEstimator

# Note the change from optimize_model to optimize_keras_for_hls4ml
# The function optimize_keras_for_hls4ml acts as a wrapper around optimize_model, parsing the hls4ml config into model attributes
from hls4ml.optimization import optimize_keras_for_hls4ml

# Create hls4ml config
default_reuse_factor = 4
default_precision = 'ac_fixed<16, 6>'
hls_config = config_from_keras_model(baseline_model, granularity='name', default_precision=default_precision, default_reuse_factor=default_reuse_factor)
hls_config['IOType'] = 'io_parallel'
hls_config['Model']['Strategy'] = 'Resource' # Strategy must be present for optimisation

# Optimize model
# Note the change from ParameterEstimator to VivadoDSPEstimator
optimized_model = optimize_keras_for_hls4ml(
baseline_model, hls_config, VivadoDSPEstimator, scheduler,
X_train, y_train, X_val, y_val, batch_size, epochs, optimizer, loss_fn, metric, increasing, rtol
)

There are two more Vivado "optimizers": VivadoFFEstimator, aimed at reducing register (flip-flop) utilization, and VivadoMultiObjectiveEstimator, aimed at optimizing BRAM and DSP utilization.
Note that, to ensure DSPs are optimized, "unrolled" Dense multiplication must be used before synthesizing HLS, by modifying the config:

.. code-block:: Python
hls_config = config_from_keras_model(optimized_model)
hls_config['Model']['DenseResourceImplementation'] = 'Unrolled'
# Any additional hls4ml config, such as strategy, reuse factor, etc.
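
As a sketch only (assuming VivadoMultiObjectiveEstimator is imported from the same module as VivadoDSPEstimator, reusing the training setup, scheduler and name-granularity hls_config defined above, and with an arbitrary output directory name), the multi-objective estimator mentioned above can be swapped in the same way, and the modified config is then passed to the standard hls4ml conversion flow:

.. code-block:: Python
import hls4ml
from hls4ml.optimization.objectives.vivado_objectives import VivadoMultiObjectiveEstimator

# Swap in the multi-objective estimator; the call signature is unchanged
optimized_model = optimize_keras_for_hls4ml(
baseline_model, hls_config, VivadoMultiObjectiveEstimator, scheduler,
X_train, y_train, X_val, y_val, batch_size, epochs, optimizer, loss_fn, metric, increasing, rtol
)

# Convert the optimized model with the modified config and build the HLS project as usual
hls_model = hls4ml.converters.convert_from_keras_model(
optimized_model, hls_config=hls_config, output_dir='hls4ml_prj', backend='Vivado'
)
hls_model.compile()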
97 changes: 97 additions & 0 deletions hls4ml/optimization/__init__.py
@@ -0,0 +1,97 @@
import numpy as np

from hls4ml.optimization.attributes import get_attributes_from_keras_model_and_hls4ml_config
from hls4ml.optimization.keras import optimize_model


def optimize_keras_for_hls4ml(
keras_model,
hls_config,
objective,
scheduler,
X_train,
y_train,
X_val,
y_val,
batch_size,
epochs,
optimizer,
loss_fn,
validation_metric,
increasing,
rtol,
callbacks=[],
ranking_metric='l1',
local=False,
verbose=False,
rewinding_epochs=1,
cutoff_bad_trials=3,
directory='hls4ml-optimization',
tuner='Bayesian',
knapsack_solver='CBC_MIP',
regularization_range=np.logspace(-6, -2, num=16).tolist(),
):
'''
Top-level function for optimizing a Keras model, given an hls4ml config and hardware objective(s)

Args:
- keras_model (keras.Model): Model to be optimized
- hls_config (dict): hls4ml configuration, obtained from hls4ml.utils.config.config_from_keras_model(...)
- objective (hls4ml.optimization.objectives.ObjectiveEstimator): Parameter, hardware or user-defined objective of optimization
- scheduler (hls4ml.optimization.scheduler.OptimizationScheduler): Sparsity scheduler, choose between constant, polynomial and binary
- X_train (np.array): Training inputs
- y_train (np.array): Training labels
- X_val (np.array): Validation inputs
- y_val (np.array): Validation labels
- batch_size (int): Batch size during training
- epochs (int): Maximum number of epochs to fine-tune model, in one iteration of pruning
- optimizer (keras.optimizers.Optimizer or equivalent-string description): Optimizer used during training
- loss_fn (keras.losses.Loss or equivalent loss description): Loss function used during training
- validation_metric (keras.metrics.Metric or equivalent metric description): Validation metric, used as a baseline
- increasing (boolean): If the metric improves with increased values; e.g. accuracy -> increasing = True, MSE -> increasing = False
- rtol (float): Relative tolerance; pruning stops when pruned_validation_metric < (or >) rtol * baseline_validation_metric

Kwargs:
- callbacks (list of keras.callbacks.Callback): Currently not supported, to be added in future versions
- ranking_metric (string): Metric used for ranking weights and structures; currently supported: l1, l2, saliency and Oracle
- local (boolean): Layer-wise or global pruning
- verbose (boolean): Display debug logs during model optimization
- rewinding_epochs (int): Number of epochs to retrain model without weight freezing, allows regrowth of previously pruned weights
- cutoff_bad_trials (int): Number of consecutive bad trials (performance below threshold) after which pruning / weight sharing stops
- directory (string): Directory to store temporary results
- tuner (str): Tuning algorithm, choose between Bayesian, Hyperband and None
- knapsack_solver (str): Algorithm for solving the Knapsack problem during optimization; the default usually works well; for very large networks, a greedy algorithm might be more suitable
- regularization_range (list): List of suitable hyperparameters for weight decay
'''

# Extract model attributes
model_attributes = get_attributes_from_keras_model_and_hls4ml_config(keras_model, hls_config)

# Optimize model
return optimize_model(
keras_model,
model_attributes,
objective,
scheduler,
X_train,
y_train,
X_val,
y_val,
batch_size,
epochs,
optimizer,
loss_fn,
validation_metric,
increasing,
rtol,
callbacks=callbacks,
ranking_metric=ranking_metric,
local=local,
verbose=verbose,
rewinding_epochs=rewinding_epochs,
cutoff_bad_trials=cutoff_bad_trials,
directory=directory,
tuner=tuner,
knapsack_solver=knapsack_solver,
regularization_range=regularization_range,
)
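
For reference, a minimal usage sketch of this wrapper; the model, data, scheduler and training objects are placeholders mirroring the documentation example in docs/advanced/model_optimization.rst, and the keyword arguments shown merely illustrate some of the options documented in the docstring:

.. code-block:: Python
from hls4ml.optimization import optimize_keras_for_hls4ml
from hls4ml.optimization.objectives.vivado_objectives import VivadoDSPEstimator

# baseline_model, hls_config, scheduler, training data and training hyperparameters
# are assumed to be defined as in the documentation example
optimized_model = optimize_keras_for_hls4ml(
baseline_model, hls_config, VivadoDSPEstimator, scheduler,
X_train, y_train, X_val, y_val, batch_size, epochs, optimizer, loss_fn, validation_metric, increasing, rtol,
ranking_metric='l2', local=False, verbose=True, rewinding_epochs=1, cutoff_bad_trials=3,
)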