Skip to content

Commit

Permalink
Add function to get proto from hex and throughput value
Browse files Browse the repository at this point in the history
This patch adds a new function to bhive_importer to directly create a
BasicBlockWithThroughputProto from a basic block in hex format and its
associated throughput value. No unittests are added on the C++ side as
this is a refactoring. A unit test is added on the python side to test
the bindings for the new function. This is intended to be used in the
benchmarking pipeline where formatting everything as a CSV before
passing it into the bhive importer makes little sense.

Pull Request: google#274
  • Loading branch information
boomanaiden154 committed Dec 30, 2024
1 parent ed01cb2 commit c47bd63
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 12 deletions.
34 changes: 22 additions & 12 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,25 @@ BHiveImporter::BasicBlockProtoFromMachineCodeHex(
base_address);
}

absl::StatusOr<BasicBlockWithThroughputProto>
BHiveImporter::BlockWithThroughputFromHexAndThroughput(
std::string_view source_name, std::string_view bb_hex, double throughput,
double throughput_scaling, uint64_t base_address) {
BasicBlockWithThroughputProto proto;
absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMachineCodeHex(bb_hex, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();
*proto.mutable_basic_block() = std::move(block_proto_or_status).value();

ThroughputWithSourceProto& throughput_proto =
*proto.add_inverse_throughputs();
throughput_proto.set_source(source_name);
throughput_proto.add_inverse_throughput_cycles(throughput *
throughput_scaling);

return proto;
}

absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
std::string_view source_name, std::string_view line,
size_t machine_code_hex_column_index, size_t throughput_column_index,
Expand All @@ -144,24 +163,15 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
columns[machine_code_hex_column_index];
const std::string_view throughput_str = columns[throughput_column_index];

BasicBlockWithThroughputProto proto;
absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMachineCodeHex(machine_code_hex, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();
*proto.mutable_basic_block() = std::move(block_proto_or_status).value();

double throughput_cycles = 0.0;
if (!absl::SimpleAtod(throughput_str, &throughput_cycles)) {
return absl::InvalidArgumentError(
absl::StrCat("Could not parse throughput value ", throughput_str));
}

ThroughputWithSourceProto& throughput = *proto.add_inverse_throughputs();
throughput.set_source(source_name);
throughput.add_inverse_throughput_cycles(throughput_cycles *
throughput_scaling);

return proto;
return BlockWithThroughputFromHexAndThroughput(
source_name, machine_code_hex, throughput_cycles, throughput_scaling,
base_address);
}

} // namespace gematria
9 changes: 9 additions & 0 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ class BHiveImporter {
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMachineCodeHex(
std::string_view machine_code_hex, uint64_t base_address = 0);

// Parses a basic block with throughput information directly from the hex
// string representing the assembly and the throughput value as a double.
absl::StatusOr<BasicBlockWithThroughputProto>
BlockWithThroughputFromHexAndThroughput(std::string_view source_name,
std::string_view bb_hex,
double throughput,
double throughput_scaling = 1.0,
uint64_t base_address = 0);

// Parses a basic block with throughput from one BHive CSV line. Expects that
// the line has the format "{machine_code},{throughput}" where {machine_code}
// is the machine code of the basic block in the hex format accepted by
Expand Down
31 changes: 31 additions & 0 deletions gematria/datasets/python/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,37 @@ PYBIND11_MODULE(bhive_importer, m) {
// TODO(ondrasej): Raise ValueError when `machine_code` does not have
// the right format.
)
.def( //
"block_with_throughput_from_hex_and_throughput",
&BHiveImporter::BlockWithThroughputFromHexAndThroughput,
py::arg("source_name"), py::arg("bb_hex"), py::arg("throughput"),
py::arg("throughput_scaling") = 1.0,
py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockWithThroughProto from a hex bb and throughput.
Creates a proto containing the basic block and associated throughput
information directly from the basic block in hex format and the
throughput as a floating point number.
Args:
source_name: The name of the throughput source used in the output
proto.
bb_hex: The basic block as a hex string.
throughput: The throughput of the basic block as a floating point
value.
throughput_scaling: An optional scaling factor applied to
{throughput}.
base_address: The address of the first instruction of the basic
block.
Returns:
A BasicBlockWithThroughputProto that contains the basic block
extracted from {bb_hex} with throughput information from
{throughput}.
Raises:
StatusNotOk: If parsing the basic block hex fails.
)")
.def( //
"basic_block_with_throughput_proto_from_csv_line",
&BHiveImporter::ParseBHiveCsvLine, py::arg("source_name"),
Expand Down
24 changes: 24 additions & 0 deletions gematria/datasets/python/bhive_importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,30 @@ def test_x86_nonstandard_columns(self):
),
)

def test_x86_block_from_hex_and_throughput(self):
source_name = "test: made-up"
importer = bhive_importer.BHiveImporter(self._x86_canonicalizer)
block_proto = importer.block_with_throughput_from_hex_and_throughput(
source_name,
"4829d38b44246c8b54246848c1fb034829d04839c3",
10,
throughput_scaling=2.0,
base_address=600,
)

self.assertEqual(
block_proto,
throughput_pb2.BasicBlockWithThroughputProto(
basic_block=_EXPECTED_BASIC_BLOCK_PROTO,
inverse_throughputs=(
throughput_pb2.ThroughputWithSourceProto(
source=source_name,
inverse_throughput_cycles=[20.0],
),
),
),
)


if __name__ == "__main__":
absltest.main()

0 comments on commit c47bd63

Please sign in to comment.