
Add FeatureCross Layer #13

Open · wants to merge 10 commits into `main`

Conversation

@abheesht17 (Collaborator) commented on Jan 23, 2025:

Made some changes from the TFRS version:

  • Named it `FeatureCross` instead of `Cross`. Should we stick with `Cross`?
  • Renamed some args.
  • And, of course, all the usual TF -> Keras 3.0 changes.
  • Cleaned up docstrings and unit tests.

The question is: do we stick with the exact same names/args that TFRS had (to make porting to KerasRS easier for existing users)?
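
For context, here is a hypothetical usage sketch of the layer as proposed. The `keras_rs.layers.FeatureCross` import path and the two-argument call signature are assumptions based on the TFRS `Cross` layer, not something this PR confirms:

```python
import keras
import keras_rs  # assumed import path

# Hypothetical usage, mirroring TFRS's Cross: each cross layer takes the
# original features x0 plus the output of the previous layer.
inputs = keras.Input(shape=(32,))
x = keras_rs.layers.FeatureCross()(inputs, inputs)
x = keras_rs.layers.FeatureCross()(inputs, x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
```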

@abheesht17 requested a review from @hertschuh · January 23, 2025 12:05
@hertschuh (Collaborator) left a comment:

Thanks for the PR!

And thanks for the improvements over the original, like the way the dense layers are created and used.

```
especially for the low-rank case).

References:
- [R. Wang et al.](https://arxiv.org/abs/2008.13535)
```
@hertschuh (Collaborator):

Unindent so that it's lined up with `References`.


````
Example:

```python
````
@hertschuh (Collaborator):

Unindent so that it's lined up with `Example`.

```
`(input_dim, projection_dim)` and `V` is of shape
`(projection_dim, input_dim)`. `projection_dim` need to be smaller
than `input_dim//2` to improve the model efficiency. In practice,
we've observed that `projection_dim = d/4` consistently preserved
```
@hertschuh (Collaborator):

Should it be `input_dim/4` instead of `d/4`? (I understand this was like this in the original code.)

@abheesht17 (Collaborator, Author):

Yep, fixed it.
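
For background, the low-rank variant being documented replaces the full `(input_dim, input_dim)` weight matrix with the product `U V`. A minimal sketch of that computation follows; the names and layer wiring are illustrative, not the PR's actual internals:

```python
import keras

input_dim, projection_dim = 32, 8  # projection_dim < input_dim // 2

# W is factored as U @ V: project down to projection_dim, then back up.
# This costs 2 * input_dim * projection_dim parameters instead of input_dim**2.
down_proj = keras.layers.Dense(projection_dim, use_bias=False)  # plays the role of V
up_proj = keras.layers.Dense(input_dim)  # plays the role of U and carries the bias

def cross(x0, x):
    # DCN-v2 interaction: x_{l+1} = x0 * (U (V x_l) + b) + x_l
    return x0 * up_proj(down_proj(x)) + x
```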

```python
def build(self, input_shape: types.TensorShape) -> None:
    last_dim = input_shape[-1]

    dense_layer_args = {
```
@hertschuh (Collaborator):

Just move this directly to line 135; there's no need for a local variable.
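
A sketch of what the suggestion amounts to. This is an illustrative fragment only, not the PR's actual `build()`:

```python
import keras

class FeatureCross(keras.layers.Layer):
    def build(self, input_shape):
        last_dim = input_shape[-1]
        # Construct the Dense layer inline instead of staging its kwargs
        # in a dense_layer_args dict first.
        self._dense = keras.layers.Dense(units=last_dim)
        self._dense.build(input_shape)
```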

```python
import keras


def clone_initializer(
```
@hertschuh (Collaborator):

I don't think this bug applies to Keras 3 (and if it does, we'll fix it). So remove this file and don't clone the initializers.

@abheesht17 (Collaborator, Author):

I see this being used everywhere in KerasHub, though. Removing it for now.

@hertschuh (Collaborator):

Oh, really? Is this something that carried over from Keras 2 and TensorFlow? Or does it still apply to Keras 3 (and if it does, is it for every backend)?

@abheesht17 (Collaborator, Author):

@hertschuh - looks like it's true for all backends. Here's a short example for JAX: https://colab.research.google.com/drive/1oY7VfaFMztOoMgueOf2TuCz5lYRQ1GXs?resourcekey=0-wnk2cldy6PkkC2qV5PVDmg&usp=sharing.

This means we need `clone_initializer`.
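
To illustrate the behavior being discussed, here is a minimal repro sketch, assuming `clone_initializer` mirrors the KerasHub utility (rebuild the initializer from its config so that an unseeded instance draws a fresh seed):

```python
import keras
import numpy as np

def clone_initializer(initializer):
    # Strings like "glorot_uniform" are stateless specs; return them as-is.
    if not isinstance(initializer, keras.initializers.Initializer):
        return initializer
    # Rebuilding from config lets an unseeded initializer draw a new seed.
    return initializer.__class__.from_config(initializer.get_config())

init = keras.initializers.GlorotUniform()

a = init(shape=(2, 2))
b = init(shape=(2, 2))
print(np.allclose(a, b))  # True: the instance keeps reusing one internal seed

c = clone_initializer(init)(shape=(2, 2))
print(np.allclose(a, c))  # False (almost surely): the clone drew a fresh seed
```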

Comment on lines 211 to 213
```python
# Typecast config to `dict`. This is not really needed,
# but `typing` throws an error if we don't do this.
config = dict(config)
```
@hertschuh (Collaborator):

Hmm, that's the issue with using type annotations when Keras doesn't have them.

This works and I think it's a bit more elegant:

```python
def get_config(self) -> dict[str, Any]:
    config: dict[str, Any] = super().get_config()
    config.update({
        ...
    })
    return config
```


```python
import keras
import numpy as np
import tensorflow as tf
```
@hertschuh (Collaborator):

Oh, we don't want to do that; this should stay backend-independent.
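
An illustrative sketch of the backend-independent alternative, using `keras.ops` and `keras.random` in place of `tf`/`np`:

```python
import keras
from keras import ops

# keras.random and keras.ops run on every backend (TF, JAX, PyTorch, NumPy),
# so a test written with them doesn't pin the suite to TensorFlow.
x = keras.random.uniform((2, 4))
row_sums = ops.sum(x, axis=-1)
```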


```diff
 from keras_rs.src import types


-class TestCase(unittest.TestCase):
+class TestCase(tf.test.TestCase, unittest.TestCase):
```
@hertschuh (Collaborator):

Why did you need to add `tf.test.TestCase`? Is it for `get_temp_dir()`?

@abheesht17 (Collaborator, Author):

Correct. If we want to avoid using `tf.test.TestCase`, I can directly use `tempfile`.
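
For reference, a sketch of the `tempfile` route, which keeps the base `TestCase` free of `tf.test.TestCase` (the test body here is illustrative):

```python
import os
import tempfile
import unittest

import keras

class TestCase(unittest.TestCase):
    def test_model_saving(self):
        model = keras.Sequential([keras.layers.Dense(1)])
        model.build((None, 4))
        # TemporaryDirectory replaces tf.test.TestCase.get_temp_dir()
        # and cleans up after itself.
        with tempfile.TemporaryDirectory() as temp_dir:
            path = os.path.join(temp_dir, "model.keras")
            model.save(path)
            restored = keras.models.load_model(path)
            self.assertIsNotNone(restored)
```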

Comment on lines 63 to 64
```python
cls: Any,
init_kwargs: Dict[Any, Any],
```
@hertschuh (Collaborator):

All you do is `model = cls(**init_kwargs)`, so just pass the model instead as a single parameter. It's easier and more flexible.
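
A before/after sketch of the suggestion; the function names and signatures are illustrative:

```python
from typing import Any

# Before (illustrative): the helper rebuilds the layer from its parts.
def run_layer_test_before(cls: Any, init_kwargs: dict[str, Any]) -> None:
    layer = cls(**init_kwargs)
    ...  # exercise the layer

# After: the caller constructs the layer and passes it in directly,
# which also allows configurations that **init_kwargs can't express.
def run_layer_test_after(layer: Any) -> None:
    ...  # exercise the layer
```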


```python
        self.assertAllClose(self.x, output)

    def test_saved_model(self):
```
@hertschuh (Collaborator):

Call it `test_model_saving`; `saved_model` is a specific format, but that's not the one used here.

@abheesht17 requested a review from @hertschuh · January 24, 2025 03:01