Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to get proto from hex and throughput value #274

Open
wants to merge 9 commits into
base: users/boomanaiden154/main.add-function-to-get-proto-from-hex-and-throughput-value
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,25 @@ BHiveImporter::BasicBlockProtoFromMachineCodeHex(
base_address);
}

absl::StatusOr<BasicBlockWithThroughputProto>
BHiveImporter::BlockWithThroughputFromHexAndThroughput(
std::string_view source_name, std::string_view bb_hex, double throughput,
double throughput_scaling, uint64_t base_address) {
BasicBlockWithThroughputProto proto;
absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMachineCodeHex(bb_hex, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();
*proto.mutable_basic_block() = std::move(block_proto_or_status).value();

ThroughputWithSourceProto& throughput_proto =
*proto.add_inverse_throughputs();
throughput_proto.set_source(source_name);
throughput_proto.add_inverse_throughput_cycles(throughput *
throughput_scaling);

return proto;
}

absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
std::string_view source_name, std::string_view line,
size_t machine_code_hex_column_index, size_t throughput_column_index,
Expand All @@ -144,24 +163,15 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
columns[machine_code_hex_column_index];
const std::string_view throughput_str = columns[throughput_column_index];

BasicBlockWithThroughputProto proto;
absl::StatusOr<BasicBlockProto> block_proto_or_status =
BasicBlockProtoFromMachineCodeHex(machine_code_hex, base_address);
if (!block_proto_or_status.ok()) return block_proto_or_status.status();
*proto.mutable_basic_block() = std::move(block_proto_or_status).value();

double throughput_cycles = 0.0;
if (!absl::SimpleAtod(throughput_str, &throughput_cycles)) {
return absl::InvalidArgumentError(
absl::StrCat("Could not parse throughput value ", throughput_str));
}

ThroughputWithSourceProto& throughput = *proto.add_inverse_throughputs();
throughput.set_source(source_name);
throughput.add_inverse_throughput_cycles(throughput_cycles *
throughput_scaling);

return proto;
return BlockWithThroughputFromHexAndThroughput(
source_name, machine_code_hex, throughput_cycles, throughput_scaling,
base_address);
}

} // namespace gematria
9 changes: 9 additions & 0 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ class BHiveImporter {
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMachineCodeHex(
std::string_view machine_code_hex, uint64_t base_address = 0);

// Parses a basic block with throughput information directly from the hex
// string representing the assembly and the throughput value as a double.
absl::StatusOr<BasicBlockWithThroughputProto>
BlockWithThroughputFromHexAndThroughput(std::string_view source_name,
std::string_view bb_hex,
double throughput,
double throughput_scaling = 1.0,
uint64_t base_address = 0);

// Parses a basic block with throughput from one BHive CSV line. Expects that
// the line has the format "{machine_code},{throughput}" where {machine_code}
// is the machine code of the basic block in the hex format accepted by
Expand Down
31 changes: 31 additions & 0 deletions gematria/datasets/python/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,37 @@ PYBIND11_MODULE(bhive_importer, m) {
// TODO(ondrasej): Raise ValueError when `machine_code` does not have
// the right format.
)
.def( //
"block_with_throughput_from_hex_and_throughput",
&BHiveImporter::BlockWithThroughputFromHexAndThroughput,
py::arg("source_name"), py::arg("bb_hex"), py::arg("throughput"),
py::arg("throughput_scaling") = 1.0,
py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockWithThroughProto from a hex bb and throughput.

Creates a proto containing the basic block and associated throughput
information directly from the basic block in hex format and the
throughput as a floating point number.

Args:
source_name: The name of the throughput source used in the output
proto.
bb_hex: The basic block as a hex string.
throughput: The throughput of the basic block as a floating point
value.
throughput_scaling: An optional scaling factor applied to
{throughput}.
base_address: The address of the first instruction of the basic
block.

Returns:
A BasicBlockWithThroughputProto that contains the basic block
extracted from {bb_hex} with throughput information from
{throughput}.

Raises:
StatusNotOk: If parsing the basic block hex fails.
)")
.def( //
"basic_block_with_throughput_proto_from_csv_line",
&BHiveImporter::ParseBHiveCsvLine, py::arg("source_name"),
Expand Down
24 changes: 24 additions & 0 deletions gematria/datasets/python/bhive_importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,30 @@ def test_x86_nonstandard_columns(self):
),
)

def test_x86_block_from_hex_and_throughput(self):
source_name = "test: made-up"
importer = bhive_importer.BHiveImporter(self._x86_canonicalizer)
block_proto = importer.block_with_throughput_from_hex_and_throughput(
source_name,
"4829d38b44246c8b54246848c1fb034829d04839c3",
10,
throughput_scaling=2.0,
base_address=600,
)

self.assertEqual(
block_proto,
throughput_pb2.BasicBlockWithThroughputProto(
basic_block=_EXPECTED_BASIC_BLOCK_PROTO,
inverse_throughputs=(
throughput_pb2.ThroughputWithSourceProto(
source=source_name,
inverse_throughput_cycles=[20.0],
),
),
),
)


if __name__ == "__main__":
absltest.main()
Loading