Skip to content

Commit

Permalink
vmsdk/python/tests: add tests for TDX (#57)
Browse files Browse the repository at this point in the history
* docs(contributor): contrib-readme-action has updated readme

* vmsdk/python/tests: add tests for TDX

This patch mainly adds some tests for TDX.
And it refactors some corresponding code accordingly.

Signed-off-by: zhongjie <[email protected]>

* docs(contributor): contrib-readme-action has updated readme

---------

Signed-off-by: zhongjie <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
intelzhongjie and github-actions[bot] authored Jan 17, 2024
1 parent 9638406 commit e3dcb82
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 79 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/vmsdk-test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ jobs:
- name: Run PyTest for VMSDK
run: |
set -ex
sudo su -c "source setupenv.sh && python3 -m pytest -v ./vmsdk/python/tests/test_sdk.py"
# Set the "PYTHONDONTWRITEBYTECODE" and "no:cacheprovider" to prevent
# generated some intermediate files by root. Othwerwise, these
# files will fail the action/checkout in the next round of running
# due to the permission issue.
sudo su -c "source setupenv.sh && \
pushd vmsdk/python/tests && \
export PYTHONDONTWRITEBYTECODE=1 && \
python3 -m pytest -p no:cacheprovider -v test_sdk.py && \
popd"
7 changes: 6 additions & 1 deletion common/python/cctrusted_base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
raise NotImplementedError("Inherited SDK class should implement this.")

@abstractmethod
def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
20 changes: 1 addition & 19 deletions common/python/cctrusted_base/imr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from abc import ABC, abstractmethod
from cctrusted_base.tcg import TcgDigest, TcgAlgorithmRegistry
from cctrusted_base.tcg import TcgDigest

class TcgIMR(ABC):
"""Common Integrated Measurement Register class."""
Expand Down Expand Up @@ -56,21 +56,3 @@ def is_valid(self):
"""
return self._index != TcgIMR._INVALID_IMR_INDEX and \
self._index <= self.max_index

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
23 changes: 23 additions & 0 deletions common/python/cctrusted_base/tdx/rtmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
RTMR (Runtime Measurement Register).
"""

from cctrusted_base.imr import TcgIMR
from cctrusted_base.tcg import TcgAlgorithmRegistry

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

RTMR_COUNT = 4
"""Intel TDX TDREPORT provides the 4 measurement registers by default."""

RTMR_LENGTH_BY_BYTES = 48
"""RTMR length by bytes."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)
12 changes: 12 additions & 0 deletions common/python/cctrusted_base/tpm/pcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
PCR (Platform Configuration Register).
"""

from cctrusted_base.imr import TcgIMR

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
3 changes: 2 additions & 1 deletion vmsdk/python/cc_imr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(CCTrustedVmSdk.inst().get_measurement_count()):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
digest_obj = imr.digest(alg.alg_id)

hash_str = ""
for hash_item in digest_obj.hash:
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/python/cc_quote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
level=logging.NOTSET,
format="%(name)s %(levelname)-8s %(message)s"
)
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
quote = CCTrustedVmSdk.inst().get_quote()
if quote is not None:
quote.dump(args.out_format == OUT_FORMAT_RAW)
else:
Expand Down
3 changes: 2 additions & 1 deletion vmsdk/python/cctrusted_vm/cvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import struct
import fcntl
from abc import abstractmethod
from cctrusted_base.imr import TdxRTMR,TcgIMR
from cctrusted_base.imr import TcgIMR
from cctrusted_base.quote import Quote
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15
from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15

Expand Down
11 changes: 8 additions & 3 deletions vmsdk/python/cctrusted_vm/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR:
algo_id = self._cvm.default_algo_id

return self._cvm.imrs[imr_index].digest(algo_id)

def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
return self._cvm.imrs[imr_index]

def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
66 changes: 66 additions & 0 deletions vmsdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Local conftest.py containing directory-specific hook implementations."""

import pytest
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.cvm import ConfidentialVM
from cctrusted_vm.sdk import CCTrustedVmSdk
import tdx_check

cnf_default_alg = {
ConfidentialVM.TYPE_CC_TDX: TcgAlgorithmRegistry.TPM_ALG_SHA384
}
"""Configurations of default algorithm.
The configurations could be different for different confidential VMs.
e.g. TDX use sha384 as the default.
"""

cnf_measurement_cnt = {
ConfidentialVM.TYPE_CC_TDX: TdxRTMR.RTMR_COUNT
}
"""Configurations of measurement count.
The configurations could be different for different confidential VMs.
"""

cnf_measurement_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_measurement_imrs
}
"""Configurations of measurement check functions.
The configurations could be different for different confidential VMs.
"""

cnf_quote_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_quote_rtmrs
}
"""Configurations of quote check functions.
The configurations could be different for different confidential VMs.
"""

@pytest.fixture(scope="module")
def vm_sdk():
"""Get VMSDK instance."""
return CCTrustedVmSdk.inst()

@pytest.fixture(scope="module")
def default_alg_id():
"""Get default algorithm."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_default_alg[cc_type]

@pytest.fixture(scope="module")
def measurement_count():
"""Get measurement count."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_measurement_cnt[cc_type]

@pytest.fixture(scope="module")
def check_measurement():
"""Return checker for measurement."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_measurement_check[cc_type]

@pytest.fixture(scope="module")
def check_quote():
"""Return checker for quote."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_quote_check[cc_type]
77 changes: 77 additions & 0 deletions vmsdk/python/tests/tdx_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""TDX specific test."""

from hashlib import sha384
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.sdk import CCTrustedVmSdk

def _replay_eventlog():
"""Get RTMRs from event log by replay."""
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES
rtmr_cnt = TdxRTMR.RTMR_COUNT
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs
assert event_logs is not None
for event in event_logs:
if isinstance(event, TcgImrEvent):
sha384_algo = sha384()
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash)
rtmrs[event.imr_index] = sha384_algo.digest()
return rtmrs

def _check_imr(imr_index: int, alg_id: int, rtmr: bytes):
"""Check individual IMR.
Compare the 4 IMR hash with the hash derived by replay event log. They are
expected to be same.
Args:
imr_index: an integer specified the IMR index.
alg_id: an integer specified the hash algorithm.
rtmr: bytes of RTMR data for comparison.
"""
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT
assert rtmr is not None
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id])
assert imr is not None
digest_obj = imr.digest(alg_id)
assert digest_obj is not None
digest_alg_id = digest_obj.alg.alg_id
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
digest_hash = digest_obj.hash
assert digest_hash is not None
assert digest_hash == rtmr, \
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}"

def tdx_check_measurement_imrs():
"""Test measurement result.
The test is done by compare the measurement register against the value
derived by replay eventlog.
"""
alg = CCTrustedVmSdk.inst().get_default_algorithms()
rtmrs = _replay_eventlog()
_check_imr(0, alg.alg_id, rtmrs[0])
_check_imr(1, alg.alg_id, rtmrs[1])
_check_imr(2, alg.alg_id, rtmrs[2])
_check_imr(3, alg.alg_id, rtmrs[3])

def tdx_check_quote_rtmrs():
"""Test quote result.
The test is done by compare the RTMRs in quote body against the value
derived by replay eventlog.
"""
quote = CCTrustedVmSdk.inst().get_quote()
assert quote is not None
assert isinstance(quote, TdxQuote)
body = quote.body
assert body is not None
assert isinstance(body, TdxQuoteBody)
rtmrs = _replay_eventlog()
assert body.rtmr0 == rtmrs[0], \
"RTMR0 doesn't equal the replay from event log!"
assert body.rtmr1 == rtmrs[1], \
"RTMR1 doesn't equal the replay from event log!"
assert body.rtmr2 == rtmrs[2], \
"RTMR2 doesn't equal the replay from event log!"
assert body.rtmr3 == rtmrs[3], \
"RTMR3 doesn't equal the replay from event log!"
104 changes: 52 additions & 52 deletions vmsdk/python/tests/test_sdk.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,55 @@
"""Containing unit test cases for sdk class"""

import pytest
from cctrusted_vm import CCTrustedVmSdk

class TestCCTrustedVmSdk():
"""Unit tests for CCTrustedVmSdk class."""

def test_get_default_algorithms(self):
"""Test get_default_algorithms() function."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None

def test_get_measurement_count(self):
"""Test get_measurement_count() function."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count is not None

def test_get_measurement_with_invalid_input(self):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = CCTrustedVmSdk.inst().get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(self):
"""Test get_measurement() function with valid input."""
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(count):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
assert digest_obj is not None

def test_get_eventlog_with_invalid_input(self):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=0)

def test_get_eventlog_with_valid_input(self):
"""Test get_eventlog() funtion with valid input."""
event_logs = CCTrustedVmSdk.inst().get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(self):
"""Test get_quote() function with valid input."""
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
assert quote is not None

def test_get_default_algorithms(vm_sdk, default_alg_id):
"""Test get_default_algorithms() function."""
algo = vm_sdk.get_default_algorithms()
assert algo is not None
assert algo.alg_id == default_alg_id

def test_get_measurement_count(vm_sdk, measurement_count):
"""Test get_measurement_count() function."""
count = vm_sdk.get_measurement_count()
assert count is not None
assert count == measurement_count

def test_get_measurement_with_invalid_input(vm_sdk):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = vm_sdk.get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = vm_sdk.get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(vm_sdk, check_measurement):
"""Test get_measurement() function with valid input."""
count = vm_sdk.get_measurement_count()
for index in range(count):
alg = vm_sdk.get_default_algorithms()
digest_obj = vm_sdk.get_measurement([index, alg.alg_id])
assert digest_obj is not None
check_measurement()

def test_get_eventlog_with_invalid_input(vm_sdk):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
vm_sdk.get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
vm_sdk.get_eventlog(start=0)

def test_get_eventlog_with_valid_input(vm_sdk):
"""Test get_eventlog() funtion with valid input."""
event_logs = vm_sdk.get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(vm_sdk, check_quote):
"""Test get_quote() function with valid input."""
quote = vm_sdk.get_quote(None, None, None)
assert quote is not None
check_quote()

0 comments on commit e3dcb82

Please sign in to comment.