Skip to content

Commit

Permalink
added functionality to inject knn and snn
Browse files Browse the repository at this point in the history
  • Loading branch information
hj-n committed Apr 15, 2023
1 parent 204abd5 commit 3c7c3de
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
40 changes: 32 additions & 8 deletions src/snc/helpers/hparam_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
e.g., knn info, distance matrix...
'''

def get_euclidean_infos(raw, emb, dist_parameter, dist_function, length, k):
def get_euclidean_infos(raw, emb, dist_parameter, dist_function, length, k, snn_knn_matrix):
raw_dist_matrix = dm.dist_matrix_gpu(raw)
emb_dist_matrix = dm.dist_matrix_gpu(emb)

Expand All @@ -35,7 +35,7 @@ def get_euclidean_infos(raw, emb, dist_parameter, dist_function, length, k):
"emb_knn" : emb_knn_info
}

def get_predefined_infos(raw, emb, dist_parameter, dist_function, length, k):
def get_predefined_infos(raw, emb, dist_parameter, dist_function, length, k, snn_knn_matrix):
raw_dist_matrix = np.zeros((length, length))
emb_dist_matrix = np.zeros((length, length))

Expand Down Expand Up @@ -65,9 +65,9 @@ def get_predefined_infos(raw, emb, dist_parameter, dist_function, length, k):



def get_snn_infos(raw, emb, dist_parameter, dist_function, length, k):
def get_snn_infos(raw, emb, dist_parameter, dist_function, length, k, snn_knn_matrix):

infos = get_euclidean_infos(raw, emb, dist_parameter, dist_function, length, k)
infos = get_euclidean_infos(raw, emb, dist_parameter, dist_function, length, k, snn_knn_matrix)

# Compute snn matrix
raw_snn_matrix = sk.snn_gpu(infos["raw_knn"], length, k)
Expand All @@ -91,6 +91,24 @@ def get_snn_infos(raw, emb, dist_parameter, dist_function, length, k):

return infos

def get_inject_snn_infos(raw, emb, dist_parameter, dist_function, length, k, snn_knn_matrix):
    '''
    Build the infos dictionary from user-injected kNN / SNN data instead of
    computing it from raw and emb.

    raw, emb, dist_function, length and k are unused here; they are accepted
    so that this function has the same signature as the other get_*_infos
    functions.

    Parameters
    - dist_parameter: mapping; only dist_parameter["alpha"] is read
    - snn_knn_matrix: mapping with the keys "raw_knn", "emb_knn", "raw_snn"
      and "emb_snn"

    Returns a dict with the keys "raw_knn", "emb_knn", "raw_snn_matrix",
    "emb_snn_matrix", "raw_dist_matrix" and "emb_dist_matrix". The snn
    matrices are divided by their maximum and the dist matrices are
    1 / (normalized snn + alpha).

    Raises TypeError if snn_knn_matrix is None and KeyError if one of the
    required keys is missing.
    '''
    if snn_knn_matrix is None:
        raise TypeError(
            "snn_knn_matrix must be provided when dist_strategy is 'inject_snn'"
        )

    alpha = dist_parameter["alpha"]
    infos = {
        "raw_knn": snn_knn_matrix["raw_knn"],
        "emb_knn": snn_knn_matrix["emb_knn"],
    }

    for space in ("raw", "emb"):
        # Work on a float64 copy: normalizing in place would overwrite the
        # caller's matrix and fails outright for integer-typed input.
        snn = np.array(snn_knn_matrix[space + "_snn"], dtype=np.float64)

        # An all-zero matrix has nothing to normalize; dividing would give NaN.
        snn_max = np.max(snn)
        if snn_max != 0:
            snn /= snn_max

        infos[space + "_snn_matrix"] = snn
        infos[space + "_dist_matrix"] = 1 / (snn + alpha)

    return infos

'''
Helper functions to extract a cluster
Expand Down Expand Up @@ -200,7 +218,7 @@ def get_predefined_cluster_distance(cluster_a, cluster_b, raw, emb, infos, dist_
INSTALLING Hyperparameter functions
'''

def install_hparam(dist_strategy, dist_parameter, dist_function, cluster_strategy, raw, emb):
def install_hparam(dist_strategy, dist_parameter, dist_function, cluster_strategy, snn_knn_matrix, raw, emb):
get_infos = None
get_a_cluster = None
get_clusterinng = None
Expand All @@ -219,6 +237,10 @@ def install_hparam(dist_strategy, dist_parameter, dist_function, cluster_strateg
get_infos = get_predefined_infos
get_a_cluster = get_a_cluster_naive
get_cluster_distance = get_predefined_cluster_distance
elif dist_strategy == "inject_snn":
get_infos = get_inject_snn_infos
get_a_cluster = get_a_cluster_snn
get_cluster_distance = get_snn_cluster_distance
else:
raise Exception("Wrong strategy choice!! check dist_strategy ('" + dist_strategy + "')")

Expand All @@ -238,30 +260,32 @@ def install_hparam(dist_strategy, dist_parameter, dist_function, cluster_strateg

return HparamFunctions(
raw, emb, dist_parameter, dist_function,
get_infos, get_a_cluster, get_clustering, get_cluster_distance
get_infos, get_a_cluster, get_clustering, get_cluster_distance, snn_knn_matrix
)


class HparamFunctions():
'''
Saving raw, emb info and setting parameter
'''
def __init__(self, raw, emb, dist_parameter, dist_function, get_infos, get_a_cluster, get_clustering, get_cluster_distance):
def __init__(self, raw, emb, dist_parameter, dist_function, get_infos, get_a_cluster, get_clustering, get_cluster_distance, snn_knn_matrix):
self.raw = raw
self.emb = emb
self.length = len(self.raw)
self.dist_parameter = dist_parameter
self.dist_function = dist_function
self.k = dist_parameter["k"]
self.snn_knn_matrix = snn_knn_matrix

## Inject functions
self.get_infos = get_infos
self.get_a_cluster = get_a_cluster
self.get_clustering = get_clustering
self.get_cluster_distance = get_cluster_distance


def preprocessing(self):
self.infos = self.get_infos(self.raw, self.emb, self.dist_parameter, self.dist_function, self.length, self.k)
self.infos = self.get_infos(self.raw, self.emb, self.dist_parameter, self.dist_function, self.length, self.k, self.snn_knn_matrix)
dissim_matrix = self.infos["raw_dist_matrix"] - self.infos["emb_dist_matrix"]

dissim_max = np.max(dissim_matrix)
Expand Down
6 changes: 4 additions & 2 deletions src/snc/snc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(
"alpha": 0.1, "k": "sqrt"
},
dist_function=None, # inject predefined distance function
cluster_strategy="dbscan" # determines the way to consider clusters
cluster_strategy="dbscan", # determines the way to consider clusters
snn_knn_matrix=None, # inject predefined similarity matrix (dist_strategy should be "inject_snn")
):
self.raw = np.array(raw, dtype=np.float64)
self.emb = np.array(emb, dtype=np.float64)
Expand All @@ -32,6 +33,7 @@ def __init__(
self.dist_parameter = dist_parameter
self.dist_function = dist_function
self.cluster_strategy = cluster_strategy
self.snn_knn_matrix = snn_knn_matrix

## target score
self.cohev_score = None
Expand Down Expand Up @@ -63,7 +65,7 @@ def fit(self, record_vis_info=False):

self.cstrat = hp.install_hparam(
self.dist_strategy, self.dist_parameter, self.dist_function,
self.cluster_strategy,
self.cluster_strategy, self.snn_knn_matrix,
self.raw, self.emb
)

Expand Down

0 comments on commit 3c7c3de

Please sign in to comment.