diff --git a/gematria/datasets/bhive_importer.cc b/gematria/datasets/bhive_importer.cc index 4ff0840e..b9d99f2d 100644 --- a/gematria/datasets/bhive_importer.cc +++ b/gematria/datasets/bhive_importer.cc @@ -121,6 +121,25 @@ BHiveImporter::BasicBlockProtoFromMachineCodeHex( base_address); } +absl::StatusOr +BHiveImporter::BlockWithThroughputFromHexAndThroughput( + std::string_view source_name, std::string_view bb_hex, double throughput, + double throughput_scaling, uint64_t base_address) { + BasicBlockWithThroughputProto proto; + absl::StatusOr block_proto_or_status = + BasicBlockProtoFromMachineCodeHex(bb_hex, base_address); + if (!block_proto_or_status.ok()) return block_proto_or_status.status(); + *proto.mutable_basic_block() = std::move(block_proto_or_status).value(); + + ThroughputWithSourceProto& throughput_proto = + *proto.add_inverse_throughputs(); + throughput_proto.set_source(source_name); + throughput_proto.add_inverse_throughput_cycles(throughput * + throughput_scaling); + + return proto; +} + absl::StatusOr BHiveImporter::ParseBHiveCsvLine( std::string_view source_name, std::string_view line, size_t machine_code_hex_column_index, size_t throughput_column_index, @@ -144,24 +163,15 @@ absl::StatusOr BHiveImporter::ParseBHiveCsvLine( columns[machine_code_hex_column_index]; const std::string_view throughput_str = columns[throughput_column_index]; - BasicBlockWithThroughputProto proto; - absl::StatusOr block_proto_or_status = - BasicBlockProtoFromMachineCodeHex(machine_code_hex, base_address); - if (!block_proto_or_status.ok()) return block_proto_or_status.status(); - *proto.mutable_basic_block() = std::move(block_proto_or_status).value(); - double throughput_cycles = 0.0; if (!absl::SimpleAtod(throughput_str, &throughput_cycles)) { return absl::InvalidArgumentError( absl::StrCat("Could not parse throughput value ", throughput_str)); } - ThroughputWithSourceProto& throughput = *proto.add_inverse_throughputs(); - throughput.set_source(source_name); - throughput.add_inverse_throughput_cycles(throughput_cycles * - throughput_scaling); - - return proto; + return BlockWithThroughputFromHexAndThroughput( + source_name, machine_code_hex, throughput_cycles, throughput_scaling, + base_address); } } // namespace gematria diff --git a/gematria/datasets/bhive_importer.h b/gematria/datasets/bhive_importer.h index 09d68160..3136922a 100644 --- a/gematria/datasets/bhive_importer.h +++ b/gematria/datasets/bhive_importer.h @@ -77,6 +77,15 @@ class BHiveImporter { absl::StatusOr BasicBlockProtoFromMachineCodeHex( std::string_view machine_code_hex, uint64_t base_address = 0); + // Parses a basic block with throughput information directly from the hex + // string representing the assembly and the throughput value as a double. + absl::StatusOr + BlockWithThroughputFromHexAndThroughput(std::string_view source_name, + std::string_view bb_hex, + double throughput, + double throughput_scaling = 1.0, + uint64_t base_address = 0); + // Parses a basic block with throughput from one BHive CSV line. Expects that // the line has the format "{machine_code},{throughput}" where {machine_code} // is the machine code of the basic block in the hex format accepted by diff --git a/gematria/datasets/python/bhive_importer.cc b/gematria/datasets/python/bhive_importer.cc index 7fc75d32..cd08c252 100644 --- a/gematria/datasets/python/bhive_importer.cc +++ b/gematria/datasets/python/bhive_importer.cc @@ -100,6 +100,37 @@ PYBIND11_MODULE(bhive_importer, m) { // TODO(ondrasej): Raise ValueError when `machine_code` does not have // the right format. ) + .def( // + "block_with_throughput_from_hex_and_throughput", + &BHiveImporter::BlockWithThroughputFromHexAndThroughput, + py::arg("source_name"), py::arg("bb_hex"), py::arg("throughput"), + py::arg("throughput_scaling") = 1.0, + py::arg("base_address") = uint64_t{0}, + R"(Creates a BasicBlockWithThroughProto from a hex bb and throughput. + + Creates a proto containing the basic block and associated throughput + information directly from the basic block in hex format and the + throughput as a floating point number. + + Args: + source_name: The name of the throughput source used in the output + proto. + bb_hex: The basic block as a hex string. + throughput: The throughput of the basic block as a floating point + value. + throughput_scaling: An optional scaling factor applied to + {throughput}. + base_address: The address of the first instruction of the basic + block. + + Returns: + A BasicBlockWithThroughputProto that contains the basic block + extracted from {bb_hex} with throughput information from + {throughput}. + + Raises: + StatusNotOk: If parsing the basic block hex fails. + )") .def( // "basic_block_with_throughput_proto_from_csv_line", &BHiveImporter::ParseBHiveCsvLine, py::arg("source_name"), diff --git a/gematria/datasets/python/bhive_importer_test.py b/gematria/datasets/python/bhive_importer_test.py index 4fd38395..93ce21c6 100644 --- a/gematria/datasets/python/bhive_importer_test.py +++ b/gematria/datasets/python/bhive_importer_test.py @@ -246,6 +246,30 @@ def test_x86_nonstandard_columns(self): ), ) + def test_x86_block_from_hex_and_throughput(self): + source_name = "test: made-up" + importer = bhive_importer.BHiveImporter(self._x86_canonicalizer) + block_proto = importer.block_with_throughput_from_hex_and_throughput( + source_name, + "4829d38b44246c8b54246848c1fb034829d04839c3", + 10, + throughput_scaling=2.0, + base_address=600, + ) + + self.assertEqual( + block_proto, + throughput_pb2.BasicBlockWithThroughputProto( + basic_block=_EXPECTED_BASIC_BLOCK_PROTO, + inverse_throughputs=( + throughput_pb2.ThroughputWithSourceProto( + source=source_name, + inverse_throughput_cycles=[20.0], + ), + ), + ), + ) + if __name__ == "__main__": absltest.main()