Skip to content

Commit

Permalink
vmsdk/python/tests: add tests for TDX
Browse files Browse the repository at this point in the history
This patch mainly adds some tests for TDX.
And it refactors some corresponding code accordingly.

Signed-off-by: zhongjie <[email protected]>
  • Loading branch information
intelzhongjie committed Jan 12, 2024
1 parent 904333a commit ff0b6ab
Show file tree
Hide file tree
Showing 13 changed files with 253 additions and 77 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/vmsdk-test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ jobs:
- name: Run PyTest for VMSDK
run: |
set -ex
sudo su -c "source setupenv.sh && python3 -m pytest -v ./vmsdk/python/tests/test_sdk.py"
# Set PYTHONDONTWRITEBYTECODE and --no-cacheprovider to prevent
# generated some intermediate files by root. Othwerwise, these
# files will fail the action/checkout in the next round of running
# due to the permission issue.
sudo su -c "source setupenv.sh && \
pushd vmsdk/python/tests && \
export PYTHONDONTWRITEBYTECODE=1 && \
python3 ./run.py --no-cacheprovider True && \
popd"
7 changes: 6 additions & 1 deletion common/python/cctrusted_base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
raise NotImplementedError("Inherited SDK class should implement this.")

@abstractmethod
def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
20 changes: 1 addition & 19 deletions common/python/cctrusted_base/imr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from abc import ABC, abstractmethod
from cctrusted_base.tcg import TcgDigest, TcgAlgorithmRegistry
from cctrusted_base.tcg import TcgDigest

class TcgIMR(ABC):
"""Common Integrated Measurement Register class."""
Expand Down Expand Up @@ -56,21 +56,3 @@ def is_valid(self):
"""
return self._index != TcgIMR._INVALID_IMR_INDEX and \
self._index <= self.max_index

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
23 changes: 23 additions & 0 deletions common/python/cctrusted_base/tdx/rtmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
RTMR (Runtime Measurement Register).
"""

from cctrusted_base.imr import TcgIMR
from cctrusted_base.tcg import TcgAlgorithmRegistry

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

RTMR_COUNT = 4
"""Intel TDX TDREPORT provides the 4 measurement registers by default."""

RTMR_LENGTH_BY_BYTES = 48
"""RTMR length by bytes."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)
12 changes: 12 additions & 0 deletions common/python/cctrusted_base/tpm/pcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
PCR (Platform Configuration Register).
"""

from cctrusted_base.imr import TcgIMR

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
3 changes: 2 additions & 1 deletion vmsdk/python/cc_imr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(CCTrustedVmSdk.inst().get_measurement_count()):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
digest_obj = imr.digest(alg.alg_id)

hash_str = ""
for hash_item in digest_obj.hash:
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/python/cc_quote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
level=logging.NOTSET,
format="%(name)s %(levelname)-8s %(message)s"
)
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
quote = CCTrustedVmSdk.inst().get_quote()
if quote is not None:
quote.dump(args.out_format == OUT_FORMAT_RAW)
else:
Expand Down
3 changes: 2 additions & 1 deletion vmsdk/python/cctrusted_vm/cvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import struct
import fcntl
from abc import abstractmethod
from cctrusted_base.imr import TdxRTMR,TcgIMR
from cctrusted_base.imr import TcgIMR
from cctrusted_base.quote import Quote
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15
from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15

Expand Down
11 changes: 8 additions & 3 deletions vmsdk/python/cctrusted_vm/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR:
algo_id = self._cvm.default_algo_id

return self._cvm.imrs[imr_index].digest(algo_id)

def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
return self._cvm.imrs[imr_index]

def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
4 changes: 4 additions & 0 deletions vmsdk/python/tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[pytest]
markers =
basic: Select the test functions for basic testing
tdx: Select the test functions for TDX testing
38 changes: 38 additions & 0 deletions vmsdk/python/tests/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Run tests for VMSDK.
"""

import argparse
import logging
import pytest
from cctrusted_vm.cvm import ConfidentialVM

LOG = logging.getLogger(__name__)

def main():
"""Run tests for VMSDK."""
parser = argparse.ArgumentParser(description="Run tests for VMSDK")
parser.add_argument(
"--no-cacheprovider",
default=False,
dest="no_cacheprovider",
help="No cacheprovider for pytest. True/False",
type=lambda x: (str(x).lower() in ['true','1', 'yes'])
)
args = parser.parse_args()
pytest_options = ["-v", "./"]
if args.no_cacheprovider is True:
pytest_options += ["-p", "no:cacheprovider"]

cc_type = ConfidentialVM.detect_cc_type()
LOG.info("CC type is %s.", ConfidentialVM.TYPE_CC_STRING[cc_type])

LOG.info("Run basic test ...")
pytest.main(pytest_options + ["-m", "basic"])

if cc_type is ConfidentialVM.TYPE_CC_TDX:
LOG.info("Run TDX specific test ...")
pytest.main(pytest_options + ["-m", "tdx"])

if __name__ == "__main__":
main()
104 changes: 54 additions & 50 deletions vmsdk/python/tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,57 @@
import pytest
from cctrusted_vm import CCTrustedVmSdk

class TestCCTrustedVmSdk():
"""Unit tests for CCTrustedVmSdk class."""

def test_get_default_algorithms(self):
"""Test get_default_algorithms() function."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None

def test_get_measurement_count(self):
"""Test get_measurement_count() function."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count is not None

def test_get_measurement_with_invalid_input(self):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = CCTrustedVmSdk.inst().get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(self):
"""Test get_measurement() function with valid input."""
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(count):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
assert digest_obj is not None

def test_get_eventlog_with_invalid_input(self):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=0)

def test_get_eventlog_with_valid_input(self):
"""Test get_eventlog() funtion with valid input."""
event_logs = CCTrustedVmSdk.inst().get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(self):
"""Test get_quote() function with valid input."""
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
assert quote is not None
@pytest.mark.basic
def test_get_default_algorithms():
"""Test get_default_algorithms() function."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None

@pytest.mark.basic
def test_get_measurement_count():
"""Test get_measurement_count() function."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count is not None

@pytest.mark.basic
def test_get_measurement_with_invalid_input():
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = CCTrustedVmSdk.inst().get_measurement([0, None])
assert measurement is not None

@pytest.mark.basic
def test_get_measurement_with_valid_input():
"""Test get_measurement() function with valid input."""
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(count):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
assert digest_obj is not None

@pytest.mark.basic
def test_get_eventlog_with_invalid_input():
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=0)

@pytest.mark.basic
def test_get_eventlog_with_valid_input():
"""Test get_eventlog() funtion with valid input."""
event_logs = CCTrustedVmSdk.inst().get_eventlog()
assert event_logs is not None

@pytest.mark.basic
def test_get_quote_with_valid_input():
"""Test get_quote() function with valid input."""
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
assert quote is not None
93 changes: 93 additions & 0 deletions vmsdk/python/tests/test_sdk_tdx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""TDX specific test."""

from hashlib import sha384
import pytest
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.sdk import CCTrustedVmSdk

@pytest.mark.tdx
def test_tdx_get_default_algorithms():
"""Test default algorithm is supported."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None
assert algo.alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384

@pytest.mark.tdx
def test_tdx_get_measurement_count():
"""Test measurement count is 4 (RTMR count)."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count == TdxRTMR.RTMR_COUNT

def _replay_eventlog():
"""Get RTMRs from event log by replay."""
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES
rtmr_cnt = TdxRTMR.RTMR_COUNT
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs
assert event_logs is not None
for event in event_logs:
if isinstance(event, TcgImrEvent):
sha384_algo = sha384()
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash)
rtmrs[event.imr_index] = sha384_algo.digest()
return rtmrs

def _check_imr(imr_index: int, alg_id: int, rtmr: bytes):
"""Check individual IMR.
Compare the 4 IMR hash with the hash derived by replay event log. They are
expected to be same.
Args:
imr_index: an integer specified the IMR index.
alg_id: an integer specified the hash algorithm.
rtmr: bytes of RTMR data for comparison.
"""
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT
assert rtmr is not None
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id])
assert imr is not None
digest_obj = imr.digest(alg_id)
assert digest_obj is not None
digest_alg_id = digest_obj.alg.alg_id
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
digest_hash = digest_obj.hash
assert digest_hash is not None
assert digest_hash == rtmr, \
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}"

@pytest.mark.tdx
def test_tdx_get_measurement_imrs():
"""Test measurement result.
The test is done by compare the measurement register against the value
derived by replay eventlog.
"""
alg = CCTrustedVmSdk.inst().get_default_algorithms()
rtmrs = _replay_eventlog()
_check_imr(0, alg.alg_id, rtmrs[0])
_check_imr(1, alg.alg_id, rtmrs[1])
_check_imr(2, alg.alg_id, rtmrs[2])
_check_imr(3, alg.alg_id, rtmrs[3])

@pytest.mark.tdx
def test_tdx_get_quote_rtmrs():
"""Test quote result.
The test is done by compare the RTMRs in quote body against the value
derived by replay eventlog.
"""
quote = CCTrustedVmSdk.inst().get_quote()
assert quote is not None
assert isinstance(quote, TdxQuote)
body = quote.body
assert body is not None
assert isinstance(body, TdxQuoteBody)
rtmrs = _replay_eventlog()
assert body.rtmr0 == rtmrs[0], \
"RTMR0 doesn't equal the replay from event log!"
assert body.rtmr1 == rtmrs[1], \
"RTMR1 doesn't equal the replay from event log!"
assert body.rtmr2 == rtmrs[2], \
"RTMR2 doesn't equal the replay from event log!"
assert body.rtmr3 == rtmrs[3], \
"RTMR3 doesn't equal the replay from event log!"

0 comments on commit ff0b6ab

Please sign in to comment.