Skip to content

Commit

Permalink
update test to check that stateful expected
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 14, 2025
1 parent 9219632 commit b12dca9
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference

from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
from optimum.exporters.openvino.stateful import model_has_state
from optimum.intel import (
OVDiffusionPipeline,
OVFluxPipeline,
Expand Down Expand Up @@ -1625,12 +1626,18 @@ class OVModelForSeq2SeqLMIntegrationTest(unittest.TestCase):
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.1

SUPPORT_STATEFUL = ("t5", "mt5")

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
ov_model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)

expected_stateful = is_transformers_version(">", "4.43") and model_arch in self.SUPPORT_STATEFUL
self.assertEqual(ov_model.decoder.stateful, expected_stateful)
self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful)
check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone
check_with_past_available(ov_model.decoder_with_past)
self.assertIsInstance(ov_model.encoder, OVEncoder)
self.assertIsInstance(ov_model.decoder, OVDecoder)
if not ov_model.decoder.stateful:
Expand Down Expand Up @@ -2339,6 +2346,12 @@ def test_compare_to_transformers(self, model_arch):
transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
self.assertIsInstance(ov_model.config, PretrainedConfig)
# whisper cache class support implemented in 4.43
expected_stateful = is_transformers_version(">", "4.43")
self.assertEqual(ov_model.decoder.stateful, expected_stateful)
self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful)
check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone
check_with_past_available(ov_model.decoder_with_past)

processor = get_preprocessor(model_id)
data = self._generate_random_audio_data()
Expand Down

0 comments on commit b12dca9

Please sign in to comment.