Skip to content

Commit

Permalink
Merge branch 'fastmachinelearning:main' into oneapi_2025
Browse files Browse the repository at this point in the history
  • Loading branch information
laurilaatu authored Jan 13, 2025
2 parents 7df2c25 + 5c85e9d commit 89929ec
Show file tree
Hide file tree
Showing 19 changed files with 289 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
args: ["--profile", "black", --line-length=125]

- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
rev: v3.19.1
hooks:
- id: pyupgrade
args: ["--py36-plus"]
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/backends/fpga/fpga_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, name):
attrs.append(ConfigurableAttribute('reuse_factor', default=1, description=descriptions.reuse_factor))
self.attribute_map[layer] = attrs

# seperable is kind of special because it is effectively two layers that will be split
# separable is kind of special because it is effectively two layers that will be split
for layer in (SeparableConv1D, SeparableConv2D):
attrs = self.attribute_map.get(layer, [])
attrs.append(TypeAttribute('depthwise_accum'))
Expand Down
40 changes: 37 additions & 3 deletions hls4ml/backends/oneapi/passes/convolution_templates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.oneapi.oneapi_template import StreamFunctionCallTemplate, TaskSequenceTemplate
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm
from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, DepthwiseConv1D, DepthwiseConv2D

# TODO - Dilation rate ?

Expand Down Expand Up @@ -70,9 +70,20 @@
conv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_conv1d_stream.h']


depthconv1d_function_template = (
'nnet::depthwise_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
)
depthconv1d_include_list = [
'nnet_utils/nnet_conv1d.h',
'nnet_utils/nnet_conv1d_resource.h',
'nnet_utils/nnet_depthconv1d.h',
'nnet_utils/nnet_depthconv1d_resource.h',
]


class Conv1DConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(Conv1D)
super().__init__((Conv1D, DepthwiseConv1D))
self.template = conv1d_config_template
self.mult_template = conv_mult_config_template

Expand Down Expand Up @@ -137,6 +148,12 @@ def format(self, node):
return self.template.format(**params)


class DepthwiseConv1DFunctionTemplate(Conv1DFunctionTemplate):
def __init__(self):
super(Conv1DFunctionTemplate, self).__init__(DepthwiseConv1D, include_header=depthconv1d_include_list)
self.template = depthconv1d_function_template


''' 2D Conv '''
conv2d_config_template = """struct config{index} : nnet::conv2d_config {{
static const unsigned in_height = {in_height};
Expand Down Expand Up @@ -183,7 +200,7 @@ def format(self, node):

class Conv2DConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__((Conv2D, Conv2DBatchnorm))
super().__init__((Conv2D, Conv2DBatchnorm, DepthwiseConv2D))
self.template = conv2d_config_template
self.mult_template = conv_mult_config_template

Expand Down Expand Up @@ -233,3 +250,20 @@ def format(self, node):
raise RuntimeError('channels_first not supported on oneAPI')
params['data_format'] = 'cl'
return self.template.format(**params)


depthconv2d_function_template = (
'nnet::depthwise_conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
)
depthconv2d_include_list = [
'nnet_utils/nnet_conv2d.h',
'nnet_utils/nnet_conv2d_resource.h',
'nnet_utils/nnet_depthconv2d.h',
'nnet_utils/nnet_depthconv2d_resource.h',
]


class DepthwiseConv2DFunctionTemplate(Conv2DFunctionTemplate):
def __init__(self):
super(Conv2DFunctionTemplate, self).__init__(DepthwiseConv2D, include_header=depthconv2d_include_list)
self.template = depthconv2d_function_template
19 changes: 19 additions & 0 deletions hls4ml/converters/pytorch/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
import numpy as np

from hls4ml.converters.pytorch_to_hls import pytorch_handler


@pytorch_handler('Constant')
def parse_constant_layer(operation, layer_name, node):
assert 'Constant' in operation

layer = {}
layer['inputs'] = []

layer['class_name'] = 'Constant'
layer['name'] = layer_name

constant = np.array(node._args)
layer['value'] = constant
output_shape = constant.shape

return layer, output_shape


@pytorch_handler('Linear')
def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
assert 'Linear' in operation
Expand Down
43 changes: 37 additions & 6 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch

from hls4ml.model import ModelGraph
Expand Down Expand Up @@ -159,6 +160,23 @@ def parse_pytorch_model(config, verbose=True):

n_inputs = 0

# check for constant nodes
merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax']
i = 0 # count number of consts and use it in the name
for node in traced_model.graph.nodes:
if node.name.split('_')[0] in merge_layers:
for arg in node.args:
if np.isscalar(arg):
# add an input node with the constant value
new_node = traced_model.graph.placeholder(
name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
)
node.prepend(new_node)
node.update_arg(1, new_node)
i += 1

traced_model.graph.lint()

for node in traced_model.graph.nodes:
if node.op == 'call_module':
# modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
Expand Down Expand Up @@ -249,13 +267,26 @@ def parse_pytorch_model(config, verbose=True):

input_layer = {}
input_layer['name'] = node.name
input_layer['class_name'] = 'InputLayer'
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
layer_list.insert(n_inputs, input_layer)

output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
input_layers.append(input_layer['name'])
n_inputs += 1
if 'const' in node.name:
pytorch_class = 'Constant'
layer, output_shape = layer_handlers[pytorch_class](pytorch_class, node.name, node)

layer_list.append(layer)

assert output_shape is not None
output_shapes[layer['name']] = output_shape

else:

input_layer['class_name'] = 'InputLayer'
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
layer_list.insert(n_inputs, input_layer)

output_shapes[input_layer['name']] = list(input_shapes[n_inputs])

input_layers.append(input_layer['name'])
n_inputs += 1

layer_counter += 1

Expand Down
2 changes: 1 addition & 1 deletion hls4ml/model/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
'convert',
[
'channels_last_converter',
'seperable_to_depthwise_and_conv',
'separable_to_depthwise_and_conv',
'remove_transpose_before_flatten',
'remove_nop_transpose',
'remove_single_channel_transpose',
Expand Down
7 changes: 6 additions & 1 deletion hls4ml/model/optimizer/passes/convert_to_channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,17 @@ def transform(self, model, node):
if (
isinstance(node, Reshape)
and len(node.attributes['target_shape']) == 1
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal"
and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "off"
):
previous_node = node.get_input_node(node.inputs[0])
input = previous_node.name
outshape = previous_node.get_output_variable().shape

if (model.config.config['IOType'] == 'io_stream') and len(outshape) == 3:
raise Exception(
'No 3D transpose available in io_stream, this model cannot be converted to channels-last'
)

if len(outshape) == 2:
attributes = {'perm': [1, 0]}
else:
Expand Down
10 changes: 5 additions & 5 deletions hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
This optimizer converts a seperable convolution to a depthwise followed by a regular convolution.
This optimizer converts a separable convolution to a depthwise followed by a regular convolution.
For backends with a custom pointwise implementations the regular convolution will subsequently
be converted to a pointwise convolution by a different optimizer.
"""
Expand All @@ -10,8 +10,8 @@
from hls4ml.model.optimizer import OptimizerPass


class SeperableToDepthwiseAndConv(OptimizerPass):
"""Convert Seperable to DepthwiseConv + Conv (potentially later Pointwise)"""
class SeparableToDepthwiseAndConv(OptimizerPass):
"""Convert Separable to DepthwiseConv + Conv (potentially later Pointwise)"""

_dw_attributes = (
'in_width',
Expand Down Expand Up @@ -70,7 +70,7 @@ def transform(self, model, node):
model.config.parse_name_config(dw_name, dw_layer_config)

# creating the attributes
dw_attributes = {k: node.attributes[k] for k in SeperableToDepthwiseAndConv._dw_attributes if k in node.attributes}
dw_attributes = {k: node.attributes[k] for k in SeparableToDepthwiseAndConv._dw_attributes if k in node.attributes}
dw_attributes['n_filt'] = dw_attributes['n_chan'] * dw_attributes['depth_multiplier']
dw_attributes['use_bias'] = False

Expand Down Expand Up @@ -100,7 +100,7 @@ def transform(self, model, node):
model.config.parse_name_config(pw_name, pw_layer_config)

# creating the attributes
pw_attributes = {k: node.attributes[k] for k in SeperableToDepthwiseAndConv._pw_attributes if k in node.attributes}
pw_attributes = {k: node.attributes[k] for k in SeparableToDepthwiseAndConv._pw_attributes if k in node.attributes}
pw_attributes['filt_width'] = 1
pw_attributes['filt_height'] = 1
pw_attributes['stride_width'] = 1
Expand Down
19 changes: 19 additions & 0 deletions hls4ml/templates/oneapi/firmware/nnet_utils/nnet_depthconv1d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef NNET_DEPTH_CONV1D_H_
#define NNET_DEPTH_CONV1D_H_

#include "nnet_common.h"
#include "nnet_conv1d.h"
#include "nnet_depthconv1d_resource.h"

namespace nnet {

template <class data_T, class res_T, typename CONFIG_T>
void depthwise_conv_1d_cl(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &weights,
const typename CONFIG_T::bias_t &biases) {

depthwise_conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
}

} // namespace nnet

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#ifndef NNET_DEPTH_CONV1D_LATENCY_H_
#define NNET_DEPTH_CONV1D_LATENCY_H_

#include "nnet_common.h"
#include "nnet_conv1d_resource.h"
#include "nnet_mult.h"

namespace nnet {

template <class data_T, class res_T, typename CONFIG_T>
void depthwise_conv_1d_resource_cl(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &weights,
const typename CONFIG_T::bias_t &biases) {

int depth_multiplier = CONFIG_T::n_filt / CONFIG_T::n_chan;
[[intel::fpga_register]] int res_idx = 0;

[[intel::fpga_register]] typename CONFIG_T::accum_t acc[CONFIG_T::out_width * CONFIG_T::n_filt];

DM_LOOP:
#pragma unroll
for (int dm = 0; dm < depth_multiplier; dm++) {

WIDTH_LOOP:
#pragma unroll
for (int w = 0; w < CONFIG_T::out_width; w++) {

CHAN_LOOP:
#pragma unroll
for (int c = 0; c < CONFIG_T::n_chan; c++) {

res_idx = (w * CONFIG_T::n_filt) + (c * depth_multiplier) + dm;

acc[res_idx] = biases[c * depth_multiplier + dm];

KERNEL_W_LOOP:
#pragma unroll
for (int kw = 0; kw < CONFIG_T::filt_width; kw++) {

int w_in = w * CONFIG_T::stride_width + kw - CONFIG_T::pad_left;

if ((w_in >= 0) && (w_in < CONFIG_T::in_width)) {

acc[res_idx] += CONFIG_T::mult_config::
template product<typename data_T::value_type, typename CONFIG_T::weight_t::value_type>::product(
data[(w_in)*CONFIG_T::n_chan + c],
weights[(dm * CONFIG_T::filt_width * CONFIG_T::n_chan) + (kw * CONFIG_T::n_chan) + c]);
}
}
}
}
}

RESULT:
#pragma unroll
for (int ires = 0; ires < CONFIG_T::out_width * CONFIG_T::n_filt; ires++) {
res[ires] = cast<typename CONFIG_T::accum_t, typename res_T::value_type, CONFIG_T>(acc[ires]);
}
}
} // namespace nnet
#endif
19 changes: 19 additions & 0 deletions hls4ml/templates/oneapi/firmware/nnet_utils/nnet_depthconv2d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef NNET_DEPTH_CONV2D_H_
#define NNET_DEPTH_CONV2D_H_

#include "nnet_common.h"
#include "nnet_conv2d.h"
#include "nnet_depthconv2d_resource.h"

namespace nnet {

template <class data_T, class res_T, typename CONFIG_T>
void depthwise_conv_2d_cl(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &weights,
const typename CONFIG_T::bias_t &biases) {

depthwise_conv_2d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
}

} // namespace nnet

#endif
Loading

0 comments on commit 89929ec

Please sign in to comment.