diff --git a/tests/smoke/test_timm.py b/tests/smoke/test_timm.py index 3c10b4d..7becceb 100644 --- a/tests/smoke/test_timm.py +++ b/tests/smoke/test_timm.py @@ -9,7 +9,7 @@ @pytest.fixture def model(): - return xinfer.create_model("timm/resnet18.a1_in1k", device="cuda", dtype="float16") + return xinfer.create_model("timm/resnet18.a1_in1k") @pytest.fixture