Skip to content

Commit

Permalink
fix pipeline test
Browse files Browse the repository at this point in the history
Signed-off-by: jiqing-feng <[email protected]>
  • Loading branch information
jiqing-feng committed Jan 15, 2025
1 parent ee7dd81 commit c86fd1c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/ipex/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ def test_text_generation_pipeline_inference(self, model_arch):
"text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE
)
inputs = "Describe a real-world application of AI."
max_new_tokens = 10 if model_arch != "qwen2" else 2
with torch.inference_mode():
transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
with torch.inference_mode():
ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens)
self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM))
self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])

Expand Down

0 comments on commit c86fd1c

Please sign in to comment.