From c86fd1ca17191f44ba221c93208f631d405210c9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 15:37:29 +0000 Subject: [PATCH] fix pipeline test Signed-off-by: jiqing-feng --- tests/ipex/test_pipelines.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index bcdc59208..62c3877b5 100644 --- a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -145,10 +145,11 @@ def test_text_generation_pipeline_inference(self, model_arch): "text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE ) inputs = "Describe a real-world application of AI." + max_new_tokens = 10 if model_arch != "qwen2" else 2 with torch.inference_mode(): - transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens) with torch.inference_mode(): - ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM)) self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])