fix: protobuf NetParameter message creation
fix: add PReLU support, ref: ethereon#187
lengyue524 committed Mar 1, 2023
1 parent b48432d commit 750914e
Showing 8 changed files with 237 additions and 106 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -1,3 +1,10 @@
# Personal fork for converting the OpenPose Caffe model

## Library versions
tensorflow 2.11.0
protobuf 3.19.6
If these libraries are updated, the converter will need to be updated to match.

# Caffe to TensorFlow

Convert [Caffe](https://github.com/BVLC/caffe/) models to [TensorFlow](https://github.com/tensorflow/tensorflow).
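
Since the regenerated `caffe_pb2.py` and the `tf.compat.v1` calls in this commit are tied to those pins, a minimal sketch (not part of the commit) for asserting them before running a conversion:

```python
# Minimal sketch, assuming only that tensorflow and protobuf are installed:
# fail fast if the library versions drift from the ones pinned in the README.
import tensorflow as tf
import google.protobuf

assert tf.__version__.startswith("2.11"), tf.__version__
assert google.protobuf.__version__.startswith("3.19"), google.protobuf.__version__
```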
2 changes: 1 addition & 1 deletion convert.py
@@ -28,7 +28,7 @@ def validate_arguments(args):

def convert(def_path, caffemodel_path, data_output_path, code_output_path, standalone_output_path, phase):
try:
sess = tf.InteractiveSession()
sess = tf.compat.v1.InteractiveSession()
transformer = TensorFlowTransformer(def_path, caffemodel_path, phase=phase)
print_stderr('Converting data...')
if data_output_path is not None:
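
For context, the `tf.compat.v1.InteractiveSession()` call only makes sense in graph mode; a minimal sketch of the TF1-style pattern under TensorFlow 2.11 (assuming eager execution is not already disabled elsewhere in convert.py):

```python
# Sketch of the legacy graph-mode setup this change keeps working on TF 2.x.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()    # run TF1-style graph code as before
sess = tf.compat.v1.InteractiveSession()  # replaces the removed tf.InteractiveSession()

x = tf.compat.v1.placeholder_with_default(tf.constant(1.0), shape=[], name='use_dropout')
print(sess.run(x))                        # 1.0
```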
112 changes: 56 additions & 56 deletions kaffe/caffe/caffe_pb2.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions kaffe/caffe/resolver.py
@@ -1,9 +1,11 @@
import sys

from google.protobuf import message_factory
from . import caffe_pb2
SHARED_CAFFE_RESOLVER = None

class CaffeResolver(object):
def __init__(self):
self.message_classes = message_factory.MessageFactory()
self.import_caffe()

def import_caffe(self):
@@ -21,7 +23,9 @@ def import_caffe(self):
# Use the protobuf code from the imported distribution.
# This way, Caffe variants with custom layers will work.
self.caffepb = self.caffe.proto.caffe_pb2
self.NetParameter = self.caffepb.NetParameter
self.NetParameter = self.caffepb.NetParameter
else:
self.NetParameter = self.message_classes.GetPrototype(descriptor=caffe_pb2.NETPARAMETER)

def has_pycaffe(self):
return self.caffe is not None
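
The fallback branch now builds `NetParameter` through protobuf's `MessageFactory` instead of requiring pycaffe. A rough equivalent under protobuf 3.19 (this sketch reaches the descriptor through `NetParameter.DESCRIPTOR` rather than the `caffe_pb2.NETPARAMETER` constant the commit references):

```python
# Sketch: build a NetParameter message class from its descriptor.
from google.protobuf import message_factory
from kaffe.caffe import caffe_pb2

factory = message_factory.MessageFactory()
NetParameter = factory.GetPrototype(caffe_pb2.NetParameter.DESCRIPTOR)

net = NetParameter()
net.name = 'example_net'   # NetParameter has a string `name` field
print(net)
```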
3 changes: 2 additions & 1 deletion kaffe/layers.py
@@ -38,6 +38,7 @@
'Pooling': shape_pool,
'Power': shape_identity,
'ReLU': shape_identity,
'PReLU': shape_identity,
'Scale': shape_identity,
'Sigmoid': shape_identity,
'SigmoidCrossEntropyLoss': shape_scalar,
@@ -81,7 +82,7 @@ class NodeDispatch(object):

@staticmethod
def get_handler_name(node_kind):
if len(node_kind) <= 4:
if len(node_kind) <= 4 or node_kind == 'PReLU':
# A catch-all for things like ReLU and tanh
return node_kind.lower()
# Convert from CamelCase to under_scored
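
The dispatch rule above means `PReLU` now resolves to a `prelu` handler instead of going through the CamelCase conversion (which would give `pre_lu`). A small sketch of the naming behaviour, with a regex that only approximates kaffe's own conversion:

```python
# Sketch of handler-name resolution; the regex approximates kaffe's rule.
import re

def get_handler_name(node_kind):
    if len(node_kind) <= 4 or node_kind == 'PReLU':
        # Short names (ReLU, TanH) and PReLU map straight to lower case.
        return node_kind.lower()
    # CamelCase -> under_scored, e.g. InnerProduct -> inner_product
    return re.sub(r'(?<=[a-z])(?=[A-Z])', '_', node_kind).lower()

print(get_handler_name('PReLU'))         # prelu (not pre_lu)
print(get_handler_name('InnerProduct'))  # inner_product
```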
162 changes: 118 additions & 44 deletions kaffe/tensorflow/network.py
@@ -1,6 +1,11 @@
import numpy as np
import pickle
import tensorflow as tf

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops

DEFAULT_PADDING = 'SAME'


@@ -41,9 +46,9 @@ def __init__(self, inputs, trainable=True):
# If true, the resulting variables are set as trainable
self.trainable = trainable
# Switch variable for dropout
self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
shape=[],
name='use_dropout')
self.use_dropout = tf.compat.v1.placeholder_with_default(tf.constant(1.0),
shape=[],
name='use_dropout')
self.setup()

def setup(self):
@@ -56,16 +61,26 @@ def load(self, data_path, session, ignore_missing=False):
session: The current TensorFlow session
ignore_missing: If true, serialized weights for missing layers are ignored.
'''
data_dict = np.load(data_path).item()
with open(data_path, 'rb') as handle:
data_dict = pickle.load(handle)
for op_name in data_dict:
with tf.variable_scope(op_name, reuse=True):
for param_name, data in data_dict[op_name].items():
with tf.compat.v1.variable_scope(op_name, reuse=True):
# TODO not sure why name mapping does not work
if 'relu' in op_name:
try:
var = tf.get_variable(param_name)
session.run(var.assign(data))
var = tf.compat.v1.get_variable(op_name)
session.run(var.assign(data_dict[op_name][0]))
except ValueError:
if not ignore_missing:
raise
else:
for param_name, data in data_dict[op_name].iteritems():
try:
var = tf.compat.v1.get_variable(param_name)
session.run(var.assign(data))
except ValueError:
if not ignore_missing:
raise

def feed(self, *args):
'''Set the input(s) for the next operation by replacing the terminal nodes.
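
The `load()` rewrite above also changes the weights file from an `np.save`d dict to a plain pickle. A minimal sketch of writing and reading that dict-of-dicts layout (the file name is hypothetical):

```python
# Sketch of the pickled layout load() now expects: {layer_name: {param_name: ndarray}}.
import pickle
import numpy as np

data_dict = {'conv1_1': {'weights': np.zeros((3, 3, 3, 64)), 'biases': np.zeros(64)}}

with open('weights.pkl', 'wb') as handle:       # hypothetical file name
    pickle.dump(data_dict, handle)

with open('weights.pkl', 'rb') as handle:
    restored = pickle.load(handle)
print(restored['conv1_1']['weights'].shape)     # (3, 3, 3, 64)
```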
@@ -74,7 +89,7 @@ def feed(self, *args):
assert len(args) != 0
self.terminals = []
for fed_layer in args:
if isinstance(fed_layer, str):
if isinstance(fed_layer, basestring):
try:
fed_layer = self.layers[fed_layer]
except KeyError:
@@ -95,42 +110,62 @@ def get_unique_name(self, prefix):

def make_var(self, name, shape):
'''Creates a new TensorFlow variable.'''
return tf.get_variable(name, shape, trainable=self.trainable)
return tf.compat.v1.get_variable(name, shape, trainable=self.trainable)

def validate_padding(self, padding):
'''Verifies that the padding is one of the supported ones.'''
assert padding in ('SAME', 'VALID')

def prelu_layer(self, x, weights, biases, name=None):
"""Computes PRelu(x * weight + biases).
Args:
x: a 2D tensor. Dimensions typically: batch, in_units
weights: a 2D tensor. Dimensions typically: in_units, out_units
biases: a 1D tensor. Dimensions: out_units
name: A name for the operation (optional). If not specified
"nn_prelu_layer" is used.
Returns:
A 2-D Tensor computing prelu(matmul(x, weights) + biases).
Dimensions typically: batch, out_units.
"""
with ops.name_scope(name, "prelu_layer", [x, weights, biases]) as name:
x = ops.convert_to_tensor(x, name="x")
weights = ops.convert_to_tensor(weights, name="weights")
biases = ops.convert_to_tensor(biases, name="biases")
xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
return self.parametric_relu(xw_plus_b, name=name)

@layer
def conv(self,
input,
inputs,
k_h,
k_w,
c_o,
s_h,
s_w,
name,
relu=True,
prelu=False,
padding=DEFAULT_PADDING,
group=1,
biased=True):
# Verify that the padding is acceptable
self.validate_padding(padding)
# Get the number of channels in the input
c_i = input.get_shape()[-1]
c_i = inputs.get_shape()[-1]
# Verify that the grouping parameter is valid
assert c_i % group == 0
assert c_o % group == 0
# Convolution for a given input and kernel
convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
with tf.variable_scope(name) as scope:
kernel = self.make_var('weights', shape=[k_h, k_w, int(c_i) / group, c_o])
with tf.compat.v1.variable_scope(name) as scope:
kernel = self.make_var('weights', shape=[k_h, k_w, c_i / group, c_o])
if group == 1:
# This is the common-case. Convolve the input without any further complications.
output = convolve(input, kernel)
output = convolve(inputs, kernel)
else:
# Split the input into groups and then convolve each of them independently
input_groups = tf.split(3, group, input)
input_groups = tf.split(3, group, inputs)
kernel_groups = tf.split(3, group, kernel)
output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
# Concatenate the groups
@@ -142,33 +177,65 @@ def conv(self,
if relu:
# ReLU non-linearity
output = tf.nn.relu(output, name=scope.name)
elif prelu:
output = self.parametric_relu(output, scope=scope)
return output

@layer
def relu(self, input, name):
return tf.nn.relu(input, name=name)
def relu(self, x, name):
return tf.nn.relu(x, name=name)

@layer
def prelu(self, x, name):
return self.parametric_relu(x, name=name)

def parametric_relu(self, x, scope=None, name="PReLU"):
""" PReLU.
Parametric Rectified Linear Unit. Base on:
https://github.com/tflearn/tflearn/blob/5c23566de6e614a36252a5828d107d001a0d0482/tflearn/activations.py#L188
Arguments:
x: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
`int16`, or `int8`.
name: A name for this activation op (optional).
Returns:
A `Tensor` with the same type as `x`.
"""
# tf.zeros(x.shape, dtype=dtype)
with tf.compat.v1.variable_scope(scope, default_name=name, values=[x]) as scope:
# W_init=tf.constant_initializer(0.0)
# alphas = tf.compat.v1.get_variable(name="alphas", shape=x.get_shape()[-1],
# initializer=W_init,
# dtype=tf.float32)
alphas = self.make_var(name, x.get_shape()[-1])
x = tf.nn.relu(x) + tf.multiply(alphas, (x - tf.abs(x))) * 0.5

x.scope = scope
x.alphas = alphas
return x

@layer
def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
def max_pool(self, x, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
self.validate_padding(padding)
return tf.nn.max_pool(input,
ksize=[1, k_h, k_w, 1],
strides=[1, s_h, s_w, 1],
padding=padding,
name=name)
return tf.nn.max_pool2d(x,
ksize=[1, k_h, k_w, 1],
strides=[1, s_h, s_w, 1],
padding=padding,
name=name)

@layer
def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
def avg_pool(self, x, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
self.validate_padding(padding)
return tf.nn.avg_pool(input,
return tf.nn.avg_pool(x,
ksize=[1, k_h, k_w, 1],
strides=[1, s_h, s_w, 1],
padding=padding,
name=name)

@layer
def lrn(self, input, radius, alpha, beta, name, bias=1.0):
return tf.nn.local_response_normalization(input,
def lrn(self, x, radius, alpha, beta, name, bias=1.0):
return tf.nn.local_response_normalization(x,
depth_radius=radius,
alpha=alpha,
beta=beta,
@@ -184,48 +251,53 @@ def add(self, inputs, name):
return tf.add_n(inputs, name=name)

@layer
def fc(self, input, num_out, name, relu=True):
with tf.variable_scope(name) as scope:
input_shape = input.get_shape()
def fc(self, x, num_out, name, relu=True, prelu=False):
with tf.compat.v1.variable_scope(name) as scope:
input_shape = x.get_shape()
if input_shape.ndims == 4:
# The input is spatial. Vectorize it first.
dim = 1
for d in input_shape[1:].as_list():
dim *= d
feed_in = tf.reshape(input, [-1, dim])
feed_in = tf.reshape(x, [-1, dim])
else:
feed_in, dim = (input, input_shape[-1].value)
feed_in, dim = (x, input_shape[-1].value)
weights = self.make_var('weights', shape=[dim, num_out])
biases = self.make_var('biases', [num_out])
op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
if relu:
op = tf.nn.relu_layer
elif prelu:
op = self.prelu_layer
else:
op = tf.compat.v1.nn.xw_plus_b
fc = op(feed_in, weights, biases, name=scope.name)
return fc

@layer
def softmax(self, input, name):
input_shape = [v.value for v in input.get_shape()]
def softmax(self, x, name):
input_shape = map(lambda v: v.value, x.get_shape())
if len(input_shape) > 2:
# For certain models (like NiN), the singleton spatial dimensions
# need to be explicitly squeezed, since they're not broadcast-able
# in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
if input_shape[1] == 1 and input_shape[2] == 1:
input = tf.squeeze(input, squeeze_dims=[1, 2])
x = tf.squeeze(x, squeeze_dims=[1, 2])
else:
raise ValueError('Rank 2 tensor input expected for softmax!')
return tf.nn.softmax(input, name=name)
return tf.nn.softmax(x, name=name)

@layer
def batch_normalization(self, input, name, scale_offset=True, relu=False):
def batch_normalization(self, x, name, scale_offset=True, relu=False, prelu=False):
# NOTE: Currently, only inference is supported
with tf.variable_scope(name) as scope:
shape = [input.get_shape()[-1]]
with tf.compat.v1.variable_scope(name) as scope:
shape = [x.get_shape()[-1]]
if scale_offset:
scale = self.make_var('scale', shape=shape)
offset = self.make_var('offset', shape=shape)
else:
scale, offset = (None, None)
output = tf.nn.batch_normalization(
input,
x,
mean=self.make_var('mean', shape=shape),
variance=self.make_var('variance', shape=shape),
offset=offset,
Expand All @@ -236,9 +308,11 @@ def batch_normalization(self, input, name, scale_offset=True, relu=False):
name=name)
if relu:
output = tf.nn.relu(output)
elif prelu:
output = self.parametric_relu(output, name=scope.name)
return output

@layer
def dropout(self, input, keep_prob, name):
def dropout(self, x, keep_prob, name):
keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
return tf.nn.dropout(input, keep, name=name)
return tf.nn.dropout(x, keep, name=name)
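
The `parametric_relu` helper above implements PReLU as `relu(x) + alpha * (x - |x|) / 2`. A standalone sketch of the same formula (assuming per-channel alphas, as in Caffe's default PReLU):

```python
# Sketch: PReLU written as relu(x) + alpha * (x - |x|) / 2,
# which equals x for x > 0 and alpha * x otherwise.
import tensorflow as tf

def prelu(x, alphas):
    return tf.nn.relu(x) + tf.multiply(alphas, x - tf.abs(x)) * 0.5

x = tf.constant([[-2.0, -1.0, 3.0]])
alphas = tf.constant([0.25, 0.25, 0.25])
print(prelu(x, alphas))   # [[-0.5, -0.25, 3.0]]
```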
12 changes: 10 additions & 2 deletions kaffe/tensorflow/transformer.py
@@ -3,7 +3,7 @@
from ..errors import KaffeError, print_stderr
from ..graph import GraphBuilder, NodeMapper
from ..layers import NodeKind
from ..transformers import (DataInjector, DataReshaper, NodeRenamer, ReLUFuser,
from ..transformers import (DataInjector, DataReshaper, NodeRenamer, ReLUFuser, PReLUFuser,
BatchNormScaleBiasFuser, BatchNormPreprocessor, ParameterNamer)

from . import network
@@ -69,6 +69,8 @@ def __init__(self, node, default=True):
self.inject_kwargs = {}
if node.metadata.get('relu', False) != default:
self.inject_kwargs['relu'] = not default
if node.metadata.get('prelu'):
self.inject_kwargs['prelu'] = node.metadata.get('prelu')

def __call__(self, *args, **kwargs):
kwargs.update(self.inject_kwargs)
@@ -104,6 +106,9 @@ def map_convolution(self, node):
def map_relu(self, node):
return TensorFlowNode('relu')

def map_prelu(self, node):
return TensorFlowNode('prelu')

def map_pooling(self, node):
pool_type = node.parameters.pool
if pool_type == 0:
@@ -263,7 +268,10 @@ def transform_data(self):
NodeKind.Convolution: (2, 3, 1, 0),

# (c_o, c_i) -> (c_i, c_o)
NodeKind.InnerProduct: (1, 0)
NodeKind.InnerProduct: (1, 0),

# one dimensional
NodeKind.PReLU: (0)
}),

# Pre-process batch normalization data
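
The transpose tuples above describe how raw Caffe blobs are reordered for TensorFlow; a small numpy sketch of what each mapping does (the shapes are illustrative):

```python
# Sketch of the reshaping rules: Caffe layouts on the left, TF layouts after transpose.
import numpy as np

conv_w = np.zeros((64, 3, 7, 7))    # Caffe conv weights: (c_o, c_i, k_h, k_w)
fc_w = np.zeros((1000, 4096))       # Caffe InnerProduct:  (c_o, c_i)
prelu_alphas = np.zeros(64)         # PReLU: one alpha per channel

print(conv_w.transpose((2, 3, 1, 0)).shape)   # (7, 7, 3, 64): (k_h, k_w, c_i, c_o)
print(fc_w.transpose((1, 0)).shape)           # (4096, 1000)
print(prelu_alphas.transpose(0).shape)        # (64,): the (0) mapping is an identity
```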
