Skip to content

Commit

Permalink
fix(clip): allow bad fsw strings
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitMY committed Jan 25, 2024
1 parent affd85b commit d10c1fb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion signwriting_evaluation/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def corpus_score(self, hypotheses: list[str], references: list[list[str]]) -> fl

def score_all(self, hypotheses: list[str], references: list[str]) -> list[list[float]]:
# Default implementation: call the score function for each hypothesis-reference pair
# The result has one inner list per hypothesis, holding self.score(h, r) for every
# reference r, in the order the references were given.
# NOTE(review): this is a diff rendering ("1 addition & 1 deletion"), so both sides
# of the changed line appear; the first return looks like the removed, pre-commit line.
return [[self.score(h, r) for r in references] for h in tqdm(hypotheses)]
# The second return looks like the added line: the same computation, but the tqdm
# progress bar is disabled when exactly one hypothesis is passed.
return [[self.score(h, r) for r in references] for h in tqdm(hypotheses, disable=len(hypotheses) == 1)]

def __str__(self):
return self.name
16 changes: 13 additions & 3 deletions signwriting_evaluation/metrics/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,19 @@


def signwriting_to_clip_image(signwriting: CLIPInput, size=224) -> Image:
img = signwriting_to_image(signwriting) if isinstance(signwriting, str) else signwriting
new_img = Image.new('RGB', (size, size), (255, 255, 255))

if isinstance(signwriting, str):
try:
img = signwriting_to_image(signwriting)
except ValueError as value_error:
# This may happen when the M box maximum values are lower
# than the symbols minimum values
print(value_error)
return new_img
else:
img = signwriting

if img.width > size or img.height > size:
return new_img

Expand Down Expand Up @@ -99,7 +109,7 @@ def get_clip_features(self, inputs: list[CLIPInput]):
missing = [clip_input for clip_input in inputs if self.cache_name(clip_input) not in self.cached_texts]

if len(missing) > 0:
pbar_disable = len(missing) < self.batch_size
pbar_disable = len(missing) <= self.batch_size
pbar = tqdm(total=len(inputs), initial=len(inputs) - len(missing),
desc="Computing CLIP features", disable=pbar_disable)

Expand All @@ -112,7 +122,7 @@ def get_clip_features(self, inputs: list[CLIPInput]):

pbar.close()

texts = tqdm(inputs, desc="Loading features cache", disable=len(inputs) < self.batch_size)
texts = tqdm(inputs, desc="Loading features cache", disable=len(inputs) <= self.batch_size)
cached_features = [self.cache[self.cache_name(text)].cpu() for text in texts]
features = torch.stack(cached_features)

Expand Down
7 changes: 6 additions & 1 deletion signwriting_evaluation/metrics/test_clip.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import unittest

import numpy as np
from PIL import Image

from signwriting_evaluation.metrics.clip import SignWritingCLIPScore
from signwriting_evaluation.metrics.clip import SignWritingCLIPScore, signwriting_to_clip_image


class TestSignWritingCLIPScore(unittest.TestCase):
Expand All @@ -23,6 +24,10 @@ def test_score_image(self):
self.assertIsInstance(score, float) # Check if the score is a float
self.assertAlmostEqual(score, 0.7759, places=2)

def test_bad_fsw_is_empty_image(self):
    """A malformed FSW string must yield a blank (all-white) CLIP image."""
    # In this FSW string the symbol's coordinates (531x539) exceed the
    # M box maximum (530x538), the case the commit is meant to tolerate.
    fsw = "M530x538S37602531x539"
    image = signwriting_to_clip_image(fsw)
    # np.alltrue was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # np.all is the supported equivalent. Every channel of every pixel
    # must equal 255 for the image to count as blank white.
    self.assertTrue(np.all(np.array(image) == 255))

if __name__ == '__main__':
unittest.main()

0 comments on commit d10c1fb

Please sign in to comment.