Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Aug 17, 2024
1 parent 0d9bd6f commit 1d897d0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
12 changes: 12 additions & 0 deletions test/data/test_multi_embedding_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,15 @@ def test_cat(device):
# case: list of non-MultiEmbeddingTensor should raise error
with pytest.raises(AssertionError):
MultiEmbeddingTensor.cat([object()], dim=0)


def test_pin_memory():
met, _ = get_fake_multi_embedding_tensor(
num_rows=2,
num_cols=3,
)
assert not met.values.is_pinned()
assert not met.offset.is_pinned()
met = met.pin_memory()
assert met.values.is_pinned()
assert met.offset.is_pinned()
18 changes: 16 additions & 2 deletions test/data/test_multi_nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_fillna_col():


@withCUDA
def test_multi_nested_tensor_basics(device):
def test_basics(device):
num_rows = 8
num_cols = 10
max_value = 100
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_multi_nested_tensor_basics(device):
cloned_multi_nested_tensor)


def test_multi_nested_tensor_different_num_rows():
def test_different_num_rows():
tensor_mat = [
[torch.tensor([1, 2, 3]),
torch.tensor([4, 5])],
Expand All @@ -331,3 +331,17 @@ def test_multi_nested_tensor_different_num_rows():
match="The length of each row must be the same",
):
MultiNestedTensor.from_tensor_mat(tensor_mat)


def test_pin_memory():
num_rows = 10
num_cols = 3
tensor = MultiNestedTensor.from_tensor_mat(
[[torch.randn(random.randint(0, 10)) for _ in range(num_cols)]
for _ in range(num_rows)])

assert not tensor.values.is_pinned()
assert not tensor.offset.is_pinned()
tensor = tensor.pin_memory()
assert tensor.values.is_pinned()
assert tensor.offset.is_pinned()

0 comments on commit 1d897d0

Please sign in to comment.