diff --git a/src/defense/models.py b/src/defense/models.py index 6df6cb6..069d9ac 100644 --- a/src/defense/models.py +++ b/src/defense/models.py @@ -1,5 +1,6 @@ from typing import List from typings.models import Model +from utils.layers import SlqLayer import numpy as np import tensorflow as tf @@ -54,3 +55,17 @@ def predict(epoch, logs): ) return [keras.callbacks.LambdaCallback(on_epoch_end=predict)] + + +class Denoiser(Model): + def _model(self) -> keras.Model: + return keras.Sequential([SlqLayer()]) + + def pre_train(self): + pass + + def post_train(self): + pass + + def custom_callbacks(self) -> List[keras.callbacks.Callback]: + pass