Skip to content

Commit

Permalink
feat(similarity): add support for sign sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitMY committed Feb 10, 2024
1 parent 29a1470 commit 4b6681a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
26 changes: 24 additions & 2 deletions signwriting_evaluation/metrics/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial import distance as dis
from signwriting.types import Sign, SignSymbol
from signwriting.formats.fsw_to_sign import fsw_to_sign
from signwriting.tokenizer import normalize_signwriting
from signwriting.types import Sign, SignSymbol

from signwriting_evaluation.metrics.base import SignWritingMetric


Expand Down Expand Up @@ -95,8 +97,28 @@ def error_rate(self, hyp: Sign, ref: Sign) -> float:
length_weight = pow(length_error, self.weight["exp_factor"])
return length_weight + mean_cost * (1 - length_weight)

def score(self, hypothesis: str, reference: str) -> float:
def score_single_sign(self, hypothesis: str, reference: str) -> float:
# Calculate the evaluate score for a given hypothesis and ref.
hyp = fsw_to_sign(hypothesis)
ref = fsw_to_sign(reference)
return pow(1 - self.error_rate(hyp, ref), 2)

def score(self, hypothesis: str, reference: str) -> float:
# Here, hypothesis and reference are both FSW strings of potentially different number of signs
hypothesis_signs = normalize_signwriting(hypothesis).split(" ")
reference_signs = normalize_signwriting(reference).split(" ")
if len(hypothesis_signs) == 1 and len(reference_signs) == 1:
return self.score_single_sign(hypothesis, reference)

# Pad with empty strings to make sure the number of signs is the same
if len(hypothesis_signs) != len(reference_signs):
max_length = max(len(hypothesis_signs), len(reference_signs))
hypothesis_signs += [""] * (max_length - len(hypothesis_signs))
reference_signs += [""] * (max_length - len(reference_signs))

# Match each hypothesis sign with each reference sign
cost_matrix = self.score_all(hypothesis_signs, reference_signs)
row_ind, col_ind = linear_sum_assignment(cost_matrix)
pairs = list(zip(row_ind, col_ind))
values = [cost_matrix[row][col] for row, col in pairs]
return sum(values) / len(values)
9 changes: 9 additions & 0 deletions signwriting_evaluation/metrics/test_similarity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest

from signwriting_evaluation.metrics.similarity import SignWritingSimilarityMetric


Expand Down Expand Up @@ -35,6 +36,14 @@ def test_corpus_score(self):
self.assertIsInstance(score, float)
self.assertAlmostEqual(score, 0.8326259781509948)

def test_multi_sign_score(self):
hypothesis_single = "M530x538S17600508x462S15a11493x494S20e00488x510S22f03469x517"
hypothesis = f"{hypothesis_single} {hypothesis_single}"
reference = "M530x538S17600508x462S12a11493x494S20e00488x510S22f13469x517"
score = self.metric.score(hypothesis, reference)
self.assertIsInstance(score, float)
self.assertAlmostEqual(score, 0.8326259781509948 / 2)

def test_bad_fsw_equals_0(self):
bad_fsw = "M<s><s>M<s>p483"
score = self.metric.corpus_score([bad_fsw], [[bad_fsw]])
Expand Down

0 comments on commit 4b6681a

Please sign in to comment.