From 622995866825f064ced4a46a88f15eebf3743dea Mon Sep 17 00:00:00 2001 From: Sam Doran Date: Fri, 3 May 2024 17:14:51 -0400 Subject: [PATCH] Add data transfer direction as a parameter --- nise/generators/aws/data_transfer_generator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/nise/generators/aws/data_transfer_generator.py b/nise/generators/aws/data_transfer_generator.py index b242b649..a65fa207 100644 --- a/nise/generators/aws/data_transfer_generator.py +++ b/nise/generators/aws/data_transfer_generator.py @@ -31,6 +31,7 @@ class DataTransferGenerator(AWSGenerator): ("{region1}-DataTransfer-Regional-Bytes", "PublicIP-{direction}", ""), ("{region1}-DataTransfer-Regional-Bytes", "InterZone-{direction}", ""), ) + DATA_TRANSFER_DIRECTIONS = ("in", "out") def __init__(self, start_date, end_date, currency, payer_account, usage_accounts, attributes=None, tag_cols=None): """Initialize the data transfer generator.""" @@ -45,16 +46,24 @@ def __init__(self, start_date, end_date, currency, payer_account, usage_accounts self._resource_id = self.attributes.get("resource_id") self._saving = float(self.attributes.get("saving", 0)) or None self._tags = self.attributes.get("tags", self._tags) + + @property + def direction(self): + if self._direction: + return self._direction.capitalize() + + return choice(self.DATA_TRANSFER_DIRECTIONS).capitalize() + def _get_data_transfer(self, rate): """Get data transfer info.""" location1, aws_region, _, storage_region1 = self._get_location() location2, _, _, storage_region2 = self._get_location() trans_desc, operation, trans_type = choice(self.DATA_TRANSFER) - trans_desc = trans_desc.format(storage_region1, storage_region2) trans_desc = trans_desc.format(region1=storage_region1, region2=storage_region2, direction=self.direction) operation = operation.format(direction=self.direction) trans_type = trans_type.format(direction=self.direction) description = f"${rate} per GB - {location1} data transfer to {location2}" + return trans_desc, operation, description, location1, location2, trans_type, aws_region def _get_product_sku(self):