Add FeatureCross Layer #13
Thanks for the PR!
And thanks for the improvements over the original, like the way the dense layers are created and used.
    especially for the low-rank case).

    References:
    - [R. Wang et al.](https://arxiv.org/abs/2008.13535)
Unindent so that it's lined up with `References`.
    Example:

    ```python
Unindent so that it's lined up with `Example`.
    `(input_dim, projection_dim)` and `V` is of shape
    `(projection_dim, input_dim)`. `projection_dim` need to be smaller
    than `input_dim//2` to improve the model efficiency. In practice,
    we've observed that `projection_dim = d/4` consistently preserved
Should it be `input_dim/4` instead of `d/4`? (I understand this was like this in the original code.)
Yep, fixed it
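For context on the `input_dim//2` bound in the docstring above, here is a small parameter-count sketch. It assumes the full-rank kernel is an `input_dim x input_dim` matrix and ignores the bias term; the helper name `cross_param_counts` is made up for illustration and is not part of the PR.

```python
def cross_param_counts(input_dim: int, projection_dim: int) -> tuple[int, int]:
    """Return (full_rank, low_rank) kernel sizes, biases ignored.

    Assumption: the full-rank kernel has shape (input_dim, input_dim), and the
    low-rank variant replaces it with U of shape (input_dim, projection_dim)
    and V of shape (projection_dim, input_dim), as the docstring describes.
    """
    full_rank = input_dim * input_dim
    low_rank = 2 * input_dim * projection_dim  # U and V together
    return full_rank, low_rank


if __name__ == "__main__":
    input_dim = 64
    for projection_dim in (input_dim // 4, input_dim // 2):
        full, low = cross_param_counts(input_dim, projection_dim)
        print(f"projection_dim={projection_dim}: full={full}, low_rank={low}")
```

Under these assumptions the two factors break even with the full kernel at `projection_dim = input_dim / 2`, so only smaller values save parameters.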
    def build(self, input_shape: types.TensorShape) -> None:
        last_dim = input_shape[-1]

        dense_layer_args = {
Just move this directly to line 135; no need for a local variable.
    import keras


    def clone_initializer(
I don't think this bug applies to Keras 3 (and if it does, we'll fix it). So remove this file and don't clone the initializers.
I see this being used everywhere in KerasHub though. Removing it for now
Oh really? Is this something that carried over from Keras 2 and TensorFlow? Or does it still apply to Keras 3 (and if it does, is it for every backend)?
@hertschuh - looks like it's true for all backends. Here's a short example for JAX: https://colab.research.google.com/drive/1oY7VfaFMztOoMgueOf2TuCz5lYRQ1GXs?resourcekey=0-wnk2cldy6PkkC2qV5PVDmg&usp=sharing.
This means we need `clone_initializer`.
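To make the concern concrete without depending on any backend, here is a toy model of the behavior. It assumes the issue is the one KerasHub's `clone_initializer` works around: an initializer instance fixes its seed when it is constructed, so reusing one instance for two weights produces identical values. `ToyUniform` is a hypothetical stand-in, not a Keras class, and this `clone_initializer` is a simplified sketch rather than the helper from the PR.

```python
import random


class ToyUniform:
    """Hypothetical seeded initializer; the seed is fixed at construction."""

    def __init__(self, seed=None):
        self._init_seed = seed  # what the user passed, possibly None
        # Assumption: an unset seed is replaced by a random one, once.
        self.seed = seed if seed is not None else random.randrange(2**31)

    def __call__(self, size):
        # A fresh generator from the same seed yields the same values each call.
        rng = random.Random(self.seed)
        return [rng.uniform(-1.0, 1.0) for _ in range(size)]

    def get_config(self):
        return {"seed": self._init_seed}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def clone_initializer(initializer):
    """Rebuild the initializer from its config so it draws a new seed."""
    return initializer.__class__.from_config(initializer.get_config())


if __name__ == "__main__":
    shared = ToyUniform()
    print(shared(3) == shared(3))                     # same instance reused
    print(shared(3) == clone_initializer(shared)(3))  # cloned before use
```

In this toy, cloning only helps when the user did not pass an explicit seed, because the clone is rebuilt from the original constructor arguments.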
    # Typecast config to `dict`. This is not really needed,
    # but `typing` throws an error if we don't do this.
    config = dict(config)
Hmm... that's the issue with using type annotations when Keras doesn't have them.
This works and I think it's a bit more elegant:
    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()
        config.update({
            ...
        })
        return config
keras_rs/src/testing/test_case.py
Outdated
    import keras
    import numpy as np
    import tensorflow as tf
Oh, we don't want to do that; this should stay backend-independent.
keras_rs/src/testing/test_case.py
Outdated
     from keras_rs.src import types


    -class TestCase(unittest.TestCase):
    +class TestCase(tf.test.TestCase, unittest.TestCase):
Why did you need to add `tf.test.TestCase`? Is it for `get_temp_dir()`?
Correct. If we want to avoid using `tf.test.TestCase`, I can directly use `tempfile`.
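A backend-independent `get_temp_dir()` could look roughly like the sketch below. It uses only the standard library; the method name mirrors the one from `tf.test.TestCase`, but the body is an assumption about how it might be written, not the code that ended up in the PR.

```python
import os
import shutil
import tempfile
import unittest


class TestCase(unittest.TestCase):
    def get_temp_dir(self) -> str:
        """Create a temporary directory that is removed after the test."""
        temp_dir = tempfile.mkdtemp()
        # addCleanup runs after the test method, whether it passed or failed.
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        return temp_dir


class ExampleTest(TestCase):
    def test_write_file(self):
        path = os.path.join(self.get_temp_dir(), "model.keras")
        with open(path, "w") as f:
            f.write("placeholder")
        self.assertTrue(os.path.exists(path))


if __name__ == "__main__":
    unittest.main()
```

Each call creates a new directory, which differs from a single per-test directory; whether that matters depends on how the saving tests use it.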
keras_rs/src/testing/test_case.py
Outdated
    cls: Any,
    init_kwargs: Dict[Any, Any],
All you do is `model = cls(**init_kwargs)`, so just pass the model instead as a single parameter. It's just easier and more flexible.
        self.assertAllClose(self.x, output)

    def test_saved_model(self):
Call it `test_model_saving`; `saved_model` is a specific format, but that's not the one used here.
Made some changes from the TFRS version:
`FeatureCross` instead of `Cross`. Should we stick with `Cross`? The question is: do we stick with the exact same name/args which TFRS had (for easier porting to KerasRS for existing users)?