Skip to content

Commit

Permalink
Improve: Fewer PyTest runs
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Nov 26, 2024
1 parent 08c7ac0 commit 6effbea
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def test_invalid_argument_handling(function, expected_error, args, kwargs):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16"])
@pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"])
Expand Down Expand Up @@ -617,7 +617,7 @@ def test_dense(ndim, dtype, metric, capability, stats_fixture):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97])
@pytest.mark.parametrize(
"dtypes", # representation datatype and compute precision
Expand Down Expand Up @@ -674,7 +674,7 @@ def test_curved(ndim, dtypes, metric, capability, stats_fixture):

@pytest.mark.skipif(is_running_under_qemu(), reason="Complex math in QEMU fails")
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97])
@pytest.mark.parametrize("dtype", ["complex128", "complex64"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand Down Expand Up @@ -707,7 +707,7 @@ def test_curved_complex(ndim, dtype, capability, stats_fixture):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand Down Expand Up @@ -747,7 +747,7 @@ def test_dense_bf16(ndim, metric, capability, stats_fixture):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 16, 33])
@pytest.mark.parametrize("metric", ["bilinear", "mahalanobis"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand Down Expand Up @@ -806,7 +806,7 @@ def test_curved_bf16(ndim, metric, capability, stats_fixture):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
@pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"])
Expand Down Expand Up @@ -852,7 +852,7 @@ def test_dense_i8(ndim, dtype, metric, capability, stats_fixture):

@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("metric", ["jaccard", "hamming"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand Down Expand Up @@ -883,7 +883,7 @@ def test_dense_bits(ndim, metric, capability, stats_fixture):

@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float32", "float16"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand Down Expand Up @@ -937,7 +937,7 @@ def test_cosine_zero_vector(ndim, dtype, capability):

@pytest.mark.skip(reason="Lacks overflow protection: https://github.com/ashvardanian/SimSIMD/issues/206") # TODO
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16"])
@pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"])
Expand Down Expand Up @@ -971,7 +971,7 @@ def test_overflow(ndim, dtype, metric, capability):

@pytest.mark.skip(reason="Lacks overflow protection: https://github.com/ashvardanian/SimSIMD/issues/206") # TODO
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [131072, 262144])
@pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand Down Expand Up @@ -1001,7 +1001,7 @@ def test_overflow_i8(ndim, metric, capability):

@pytest.mark.skipif(is_running_under_qemu(), reason="Complex math in QEMU fails")
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["complex128", "complex64"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand All @@ -1018,23 +1018,19 @@ def test_dot_complex(ndim, dtype, capability, stats_fixture):
result = np.array(result)

np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)
collect_errors(
"dot", ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture
)
collect_errors("dot", ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture)

accurate_dt, accurate = profile(np.vdot, a.astype(np.complex128), b.astype(np.complex128))
expected_dt, expected = profile(np.vdot, a, b)
result_dt, result = profile(simd.vdot, a, b)
result = np.array(result)

np.testing.assert_allclose(result, expected, atol=SIMSIMD_ATOL, rtol=SIMSIMD_RTOL)
collect_errors(
"vdot", ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture
)
collect_errors("vdot", ndim, dtype, accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture)


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(100)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("dtype", ["uint16", "uint32"])
@pytest.mark.parametrize("first_length_bound", [10, 100, 1000])
@pytest.mark.parametrize("second_length_bound", [10, 100, 1000])
Expand Down Expand Up @@ -1064,7 +1060,7 @@ def test_intersect(dtype, first_length_bound, second_length_bound, capability):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
@pytest.mark.parametrize("kernel", ["fma"])
Expand Down Expand Up @@ -1125,7 +1121,7 @@ def test_fma(ndim, dtype, kernel, capability, stats_fixture):


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "int8", "uint8"])
@pytest.mark.parametrize("kernel", ["wsum"])
Expand Down Expand Up @@ -1395,7 +1391,7 @@ def test_cdist_complex(ndim, input_dtype, out_dtype, metric, capability):

@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.repeat(5)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("out_dtype", [None, "float32", "float16", "int8"])
@pytest.mark.parametrize("capability", possible_capabilities)
Expand Down

0 comments on commit 6effbea

Please sign in to comment.