Skip to content

Commit

Permalink
vmsdk/python/tests: add tests for TDX
Browse files Browse the repository at this point in the history
Signed-off-by: zhongjie <[email protected]>
  • Loading branch information
intelzhongjie committed Jan 10, 2024
1 parent 904333a commit afe6941
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 2 deletions.
3 changes: 2 additions & 1 deletion vmsdk/python/cc_imr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(CCTrustedVmSdk.inst().get_measurement_count()):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
digest_obj = imr.digest(alg.alg_id)

hash_str = ""
for hash_item in digest_obj.hash:
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/python/cctrusted_vm/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR:
algo_id = self._cvm.default_algo_id

return self._cvm.imrs[imr_index].digest(algo_id)
return self._cvm.imrs[imr_index]

def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
"""Get the quote for given nonce and data.
Expand Down
115 changes: 115 additions & 0 deletions vmsdk/python/tests/tdx/test_sdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""TDX specific test."""

import logging
import re
from hashlib import sha384
from tests import test_sdk
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent
from cctrusted_vm.sdk import CCTrustedVmSdk

LOG = logging.getLogger(__name__)

class TestCCTrustedVmSdkTdx(test_sdk.TestCCTrustedVmSdk):
"""TDX specific test."""

MEASUREMENT_COUNT = 4
"""Intel TDX TDREPORT provides the 4 measurement registers by default."""

PROC_CMDLINE= "/proc/cmdline"
"""Linux kernel cmdline arguments."""

IMA_SUPPORTED_POLICIES = ["critical_data", "tcb", "fail_securely", ""]
"""IMA supported policies.
When the policy is not specified or empty ("ima_policy="), the kernel
will measure boot_aggregate by default.
"""

RTMR_LENGTH_BY_BYTES = 48
"""RTMR length by bytes."""

def test_get_default_algorithms(self):
"""Test default algorithm is supported."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None
assert algo.alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384

def test_get_measurement_count(self):
"""Test measurement count is 4 (RTMR count)."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count == TestCCTrustedVmSdkTdx.MEASUREMENT_COUNT

def replay_eventlog(self):
"""Get RTMRs from event log by replay."""
rtmr_len = TestCCTrustedVmSdkTdx.RTMR_LENGTH_BY_BYTES
rtmr_cnt = TestCCTrustedVmSdkTdx.MEASUREMENT_COUNT
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs
assert event_logs is not None
for event in event_logs:
if isinstance(event, TcgImrEvent):
sha384_algo = sha384()
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash)
rtmrs[event.imr_index] = sha384_algo.digest()
return rtmrs

def check_imr(self, imr_index: int, alg_id: int, rtmr: bytes):
"""Check individual IMR.
Compare the 4 IMR hash with the hash derived by replay event log. They are
expected to be same.
Args:
imr_index: an integer specified the IMR index.
alg_id: an integer specified the hash algorithm.
rtmr: bytes of RTMR data for comparison.
"""
assert 0 <= imr_index < TestCCTrustedVmSdkTdx.MEASUREMENT_COUNT
assert rtmr is not None
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id])
assert imr is not None
digest_obj = imr.digest(alg_id)
assert digest_obj is not None
digest_alg_id = digest_obj.alg.alg_id
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
digest_hash = digest_obj.hash
assert digest_hash is not None
assert digest_hash == rtmr, \
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}"

def test_get_measurement(self):
"""Test measurement result.
The test is done by compare the measurement register against the value
derived by replay eventlog.
"""
cmdline = None
try:
with open(TestCCTrustedVmSdkTdx.PROC_CMDLINE, "r", encoding="utf-8") as proc_cmdline:
cmdline = proc_cmdline.readline()
except (PermissionError, OSError):
LOG.error("Need root permission to open file %s", TestCCTrustedVmSdkTdx.PROC_CMDLINE)
assert False
assert cmdline is not None

ima_policy = None
if "ima_hash=sha384" in cmdline:
m = re.search(r".*ima_policy=(\w+).*", cmdline)
if m is not None:
ima_policy = m.group(1)
LOG.info("ima_policy is %s", ima_policy)
if (ima_policy is None or
ima_policy in TestCCTrustedVmSdkTdx.IMA_SUPPORTED_POLICIES):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
rtmrs = self.replay_eventlog()
self.check_imr(0, alg.alg_id, rtmrs[0])
self.check_imr(1, alg.alg_id, rtmrs[1])
self.check_imr(2, alg.alg_id, rtmrs[2])
self.check_imr(3, alg.alg_id, rtmrs[3])
else:
LOG.info("IMA is not enabled! Unable to compare with event log!")

def test_get_eventlog(self):
"""Test get_eventlog result."""
#TODO: verify the eventlog value.

def test_get_quote(self):
"""Test get_quote result."""
#TODO: verify the quote value.

0 comments on commit afe6941

Please sign in to comment.