diff --git a/test/data/test_multi_embedding_tensor.py b/test/data/test_multi_embedding_tensor.py index 2752bc0c2..5b67407a6 100644 --- a/test/data/test_multi_embedding_tensor.py +++ b/test/data/test_multi_embedding_tensor.py @@ -476,3 +476,15 @@ def test_cat(device): # case: list of non-MultiEmbeddingTensor should raise error with pytest.raises(AssertionError): MultiEmbeddingTensor.cat([object()], dim=0) + + +def test_pin_memory(): + met, _ = get_fake_multi_embedding_tensor( + num_rows=2, + num_cols=3, + ) + assert not met.values.is_pinned() + assert not met.offset.is_pinned() + met = met.pin_memory() + assert met.values.is_pinned() + assert met.offset.is_pinned() diff --git a/test/data/test_multi_nested_tensor.py b/test/data/test_multi_nested_tensor.py index 62594ac7e..166336d3e 100644 --- a/test/data/test_multi_nested_tensor.py +++ b/test/data/test_multi_nested_tensor.py @@ -87,7 +87,7 @@ def test_fillna_col(): @withCUDA -def test_multi_nested_tensor_basics(device): +def test_basics(device): num_rows = 8 num_cols = 10 max_value = 100 @@ -317,7 +317,7 @@ def test_multi_nested_tensor_basics(device): cloned_multi_nested_tensor) -def test_multi_nested_tensor_different_num_rows(): +def test_different_num_rows(): tensor_mat = [ [torch.tensor([1, 2, 3]), torch.tensor([4, 5])], @@ -331,3 +331,17 @@ def test_multi_nested_tensor_different_num_rows(): match="The length of each row must be the same", ): MultiNestedTensor.from_tensor_mat(tensor_mat) + + +def test_pin_memory(): + num_rows = 10 + num_cols = 3 + tensor = MultiNestedTensor.from_tensor_mat( + [[torch.randn(random.randint(0, 10)) for _ in range(num_cols)] + for _ in range(num_rows)]) + + assert not tensor.values.is_pinned() + assert not tensor.offset.is_pinned() + tensor = tensor.pin_memory() + assert tensor.values.is_pinned() + assert tensor.offset.is_pinned()