From 72daf0e6fca828d062c7c462739d30c17f5ce33e Mon Sep 17 00:00:00 2001
From: dnth
Date: Fri, 8 Nov 2024 16:56:52 +0800
Subject: [PATCH] Update Florence-2 tests and simplify Result handling in infer

---
 tests/smoke/test_florence2.py    | 25 +++++++++++++++++--------
 xinfer/transformers/florence2.py |  6 ++----
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/tests/smoke/test_florence2.py b/tests/smoke/test_florence2.py
index 6471b09..7894d7e 100644
--- a/tests/smoke/test_florence2.py
+++ b/tests/smoke/test_florence2.py
@@ -5,6 +5,8 @@
 
 import xinfer
 
+TEST_DATA_DIR = Path(__file__).parent.parent / "test_data"
+
 
 @pytest.fixture
 def model():
@@ -14,8 +16,11 @@ def model():
 
 
 @pytest.fixture
-def test_image():
-    return str(Path(__file__).parent.parent / "test_data" / "test_image_1.jpg")
+def test_images():
+    return [
+        str(TEST_DATA_DIR / "test_image_1.jpg"),
+        str(TEST_DATA_DIR / "test_image_2.jpg"),
+    ]
 
 
 def test_florence2_initialization(model):
@@ -24,17 +29,21 @@ def test_florence2_initialization(model):
     assert model.dtype == torch.float32
 
 
-def test_florence2_inference(model, test_image):
+def test_florence2_inference(model, test_images):
     prompt = "<CAPTION>"
-    result = model.infer(test_image, prompt)
+    result = model.infer(test_images[0], prompt)
 
-    assert isinstance(result, str)
-    assert len(result) > 0
+    assert isinstance(result.text, str)
+    assert len(result.text) > 0
 
 
-def test_florence2_batch_inference(model, test_image):
+def test_florence2_batch_inference(model, test_images):
     prompt = "<CAPTION>"
-    result = model.infer_batch([test_image, test_image], [prompt, prompt])
+    result = model.infer_batch(test_images, [prompt, prompt])
 
     assert isinstance(result, list)
     assert len(result) == 2
+    assert isinstance(result[0].text, str)
+    assert isinstance(result[1].text, str)
+    assert len(result[0].text) > 0
+    assert len(result[1].text) > 0
diff --git a/xinfer/transformers/florence2.py b/xinfer/transformers/florence2.py
index b978b43..d19f5a2 100644
--- a/xinfer/transformers/florence2.py
+++ b/xinfer/transformers/florence2.py
@@ -47,7 +47,7 @@ def load_model(self):
     @track_inference
     def infer(self, image: str, text: str, **generate_kwargs) -> Result:
         output = self.infer_batch([image], [text], **generate_kwargs)
-        return Result(text=output[0])
+        return output[0]
 
     @track_inference
     def infer_batch(
@@ -81,6 +81,4 @@
             for text, prompt, img in zip(generated_text, texts, images)
         ]
 
-        results = [Result(text=text) for text in parsed_answers]
-
-        return results
+        return [Result(text=text) for text in parsed_answers]
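
Usage note: with this patch, infer() hands back the first Result from
infer_batch() instead of re-wrapping its text, so single-image and batch
calls both yield Result objects and callers read the parsed caption via
result.text. A minimal sketch of the new call pattern; the model fixture's
body is not shown in this diff, so xinfer.create_model and the Florence-2
checkpoint id below are assumptions for illustration only:

    import xinfer

    # Assumed factory call and model id; the diff only shows
    # `import xinfer` and the infer/infer_batch signatures.
    model = xinfer.create_model(
        "microsoft/Florence-2-base-ft", device="cpu", dtype="float32"
    )

    # Single image: one Result, with the parsed caption in .text
    result = model.infer("tests/test_data/test_image_1.jpg", "<CAPTION>")
    print(result.text)

    # Batch: one Result per (image, prompt) pair, in order
    results = model.infer_batch(
        [
            "tests/test_data/test_image_1.jpg",
            "tests/test_data/test_image_2.jpg",
        ],
        ["<CAPTION>", "<CAPTION>"],
    )
    print([r.text for r in results])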