Fix bit overflow with softmax #887
Conversation
I wonder if the optimizer should really be in the backends directory. I think optimizers in model/optimizers
should not concern themselves with backends.
Moved
test/pytest/test_softmax.py
Outdated
    x[:, imax] *= 10
    return x


def normal_dist(shape):
    return np.clip(np.random.normal(0, 8, shape), -32, 31)
Why are we getting rid of existing distributions?
The existing distribution generation does not feel necessary, and I don't see a good reason to run the tests twice on uniform and uniform-plus-manual-outliers data. A Gaussian also gives values at many scales, producing some outliers by default.
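For context, a minimal sketch of the two generation strategies being discussed (function names and the exact constants are illustrative, not necessarily those in the test file):

import numpy as np

def uniform_with_outliers(shape):
    # Old approach: flat distribution, then scale up some randomly chosen
    # positions by hand to create outliers.
    x = np.random.rand(*shape)
    imax = np.random.randint(0, shape[1], size=shape[0])
    x[:, imax] *= 10
    return x

def clipped_gaussian(shape):
    # New approach: a wide Gaussian naturally spans many scales and produces
    # occasional outliers; clipping keeps values in a fixed-point-friendly range.
    return np.clip(np.random.normal(0, 8, shape), -32, 31)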
test/pytest/test_softmax.py
Outdated
X = generate_data
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax'))
model.compile()
f_type = 'ac_fixed<18,8,true,AC_RND,AC_SAT>' if backend == 'Quartus' else 'ap_fixed<18,8,AP_RND,AP_SAT>'
f_type = (
    f'ac_fixed<{table_bits},true,AC_RND,AC_SAT>' if backend == 'Quartus' else f'ap_fixed<{table_bits},AP_RND,AP_SAT>'
)
I think the new standard for tests that use multiple backends is to just use fixed<W,I>.
These are inherited from the old test file. Will the SAT / RND flags also be auto-converted? Will change.
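As a sketch of the suggested backend-agnostic style (assuming the generic precision string also accepts rounding/saturation modes, which is exactly the open question here; table_bits and table_int_bits are placeholder names, and cfg is the layer config used in the test below):

# Generic precision string; hls4ml translates it to ap_fixed/ac_fixed for the
# selected backend, so the per-backend ternary above becomes unnecessary.
f_type = f'fixed<{table_bits}, {table_int_bits}, RND, SAT>'
cfg['LayerName']['softmax']['inv_table_t'] = f_type
cfg['LayerName']['softmax']['exp_table_t'] = f_type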
test/pytest/test_softmax.py
Outdated
cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
cfg['LayerName']['softmax']['Strategy'] = strategy
cfg['LayerName']['softmax']['inv_table_t'] = f_type
cfg['LayerName']['softmax']['exp_table_t'] = f_type
cfg['LayerName']['softmax_input']['Precision']['result'] = f'ap_fixed<{input_bits}>'
Related to the comment above: here you're actually relying on what is essentially a bug for this to parse correctly even in the case of Quartus.
test/pytest/test_softmax.py
Outdated
y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
keras_trace = hls4ml.model.profiling.get_ymodel_keras(model, X)
np.testing.assert_allclose(y_hls4ml, keras_trace['dense'], rtol=0, atol=2e-2)
y_keras = model.layers[0](X).numpy()  # type: ignore
Why do you need to do it like this?
The original code was redundant (and confusing to me): what is done here is essentially skipping the last layer in Keras. As there are only two layers, applying just the first one is enough.
I will revert if you think the original is better.
I think we want just model.predict(X). The section looks like it was cut from a different test. We should use the opportunity to clean up (also the surrounding comments).
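A minimal sketch of the cleaned-up comparison being suggested here (variable names taken from the snippet above):

# Compare Keras and hls4ml outputs directly, without the profiling detour.
y_keras = model.predict(X)
y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
np.testing.assert_allclose(y_hls4ml, y_keras, rtol=0, atol=2e-2)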
In that case I shall purge this part, if an equivalent test already exists somewhere.
Description
Fix for #885: softmax now works with fewer than 10 bits of input/table size. Tested with Vivado, Vitis, and Quartus.
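For reference, a rough sketch of how the small-bit-width regime from #885 can be exercised through the public hls4ml API (the layer setup mirrors the test above; the exact widths, the 'stable' strategy value, and the output_dir are illustrative assumptions, not the exact code in this PR):

import tensorflow as tf
import hls4ml

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Activation(input_shape=(8,), activation='softmax', name='softmax'))
model.compile()

cfg = hls4ml.utils.config_from_keras_model(model, granularity='name')
cfg['LayerName']['softmax']['Strategy'] = 'stable'
# Fewer than 10 bits for the exp/inv tables is the regime this PR fixes.
cfg['LayerName']['softmax']['exp_table_t'] = 'fixed<8,4>'
cfg['LayerName']['softmax']['inv_table_t'] = 'fixed<8,4>'

hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=cfg, backend='Vivado', output_dir='softmax_overflow_prj'
)
hls_model.compile()

The resulting hls_model.predict output can then be compared against model.predict, as in the test discussed above.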
Type of change
Bug fix (fixes #885)
Tests
Test Configuration:
pytest/test_softmax.py
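The module can be run on its own with a standard pytest invocation (path as it appears in the file headers above):

pytest test/pytest/test_softmax.py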
Checklist
I have run pre-commit on the files I edited or added.