Skip to content

Commit

Permalink
Update BB benchmarking pipeline to output tfrecords
Browse files Browse the repository at this point in the history
This patch updates the BB benchmarking pipeline to directly output
tfrecord files containing BasicBlockWithThroughputProtos. This makes it
a lot simpler to begin training on large datasets as no postprocessing
is needed to import the basic blocks from the BHive CSV format.

Pull Request: google#275
  • Loading branch information
boomanaiden154 committed Dec 31, 2024
1 parent 3fc43f7 commit 81ef300
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 20 deletions.
5 changes: 5 additions & 0 deletions gematria/datasets/pipelines/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ gematria_py_binary(
srcs = ["benchmark_bbs_lib.py"],
deps = [
":benchmark_cpu_scheduler",
"//gematria/datasets/python:bhive_importer",
"//gematria/datasets/python:exegesis_benchmark",
"//gematria/llvm/python:canonicalizer",
"//gematria/llvm/python:llvm_architecture_support",
"//gematria/proto:execution_annotation_py_pb2",
"//gematria/proto:throughput_py_pb2",
],
)

Expand All @@ -82,6 +86,7 @@ gematria_py_test(
":benchmark_cpu_scheduler",
"//gematria/io/python:tfrecord",
"//gematria/proto:execution_annotation_py_pb2",
"//gematria/proto:throughput_py_pb2",
],
)

Expand Down
30 changes: 22 additions & 8 deletions gematria/datasets/pipelines/benchmark_bbs_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from gematria.proto import execution_annotation_pb2
from gematria.datasets.python import exegesis_benchmark
from gematria.datasets.pipelines import benchmark_cpu_scheduler
from gematria.proto import throughput_pb2
from gematria.llvm.python import canonicalizer
from gematria.llvm.python import llvm_architecture_support
from gematria.datasets.python import bhive_importer

_BEAM_METRIC_NAMESPACE_NAME = 'benchmark_bbs'

Expand Down Expand Up @@ -71,14 +75,21 @@ def process(
pass


class FormatBBsForOutput(beam.DoFn):
"""A Beam function for formatting hex/throughput values for output."""
class SerializeToProto(beam.DoFn):
  """A Beam function for formatting hex/throughput values to protos."""

  def setup(self):
    # Construct the importer once per worker. It turns raw machine-code hex
    # plus a measured throughput into a BasicBlockWithThroughputProto, using
    # the x86-64 canonicalizer backed by LLVM.
    llvm_support = llvm_architecture_support.LlvmArchitectureSupport.x86_64()
    block_canonicalizer = canonicalizer.Canonicalizer.x86_64(llvm_support)
    self._x86_llvm = llvm_support
    self._x86_canonicalizer = block_canonicalizer
    self._importer = bhive_importer.BHiveImporter(block_canonicalizer)

  def process(
      self, block_hex_and_throughput: tuple[str, float]
  ) -> Iterable[throughput_pb2.BasicBlockWithThroughputProto]:
    hex_value, measured_throughput = block_hex_and_throughput
    block_proto = self._importer.block_with_throughput_from_hex_and_throughput(
        'pipeline', hex_value, measured_throughput
    )
    yield block_proto


def benchmark_bbs(
Expand All @@ -99,12 +110,15 @@ def pipeline(root: beam.Pipeline) -> None:
benchmarked_blocks = annotated_bbs_shuffled | 'Benchmarking' >> beam.ParDo(
BenchmarkBasicBlock(benchmark_scheduler_type)
)
formatted_output = benchmarked_blocks | 'Formatting' >> beam.ParDo(
FormatBBsForOutput()
block_protos = benchmarked_blocks | 'Serialize to protos' >> beam.ParDo(
SerializeToProto()
)

_ = formatted_output | 'Write To Text' >> beam.io.WriteToText(
output_file_pattern
_ = block_protos | 'Write serialized blocks' >> beam.io.WriteToTFRecord(
output_file_pattern,
coder=beam.coders.ProtoCoder(
throughput_pb2.BasicBlockWithThroughputProto().__class__
),
)

return pipeline
30 changes: 18 additions & 12 deletions gematria/datasets/pipelines/benchmark_bbs_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

from absl.testing import absltest
from apache_beam.testing import test_pipeline
from apache_beam.testing import util as beam_test

from gematria.datasets.pipelines import benchmark_bbs_lib
from gematria.proto import execution_annotation_pb2
from gematria.io.python import tfrecord
from gematria.datasets.pipelines import benchmark_cpu_scheduler
from gematria.proto import throughput_pb2

BLOCK_FOR_TESTING = execution_annotation_pb2.BlockWithExecutionAnnotations(
execution_annotations=execution_annotation_pb2.ExecutionAnnotations(
Expand Down Expand Up @@ -60,14 +60,14 @@ def test_benchmark_basic_block(self):
self.assertEqual(block_hex, '3b31')
self.assertLess(block_throughput, 10)

def test_format_bbs(self):
format_transform = benchmark_bbs_lib.FormatBBsForOutput()
def test_serialize_bbs_to_protos(self):
serialize_transform = benchmark_bbs_lib.SerializeToProto()
serialize_transform.setup()

benchmarked_block_data = ('3b31', 5)

output = list(format_transform.process(benchmarked_block_data))
output = list(serialize_transform.process(benchmarked_block_data))
self.assertLen(output, 1)
self.assertEqual(output[0], '3b31,5')

def test_benchmark_bbs(self):
test_tfrecord = self.create_tempfile()
Expand All @@ -85,13 +85,19 @@ def test_benchmark_bbs(self):
with test_pipeline.TestPipeline() as pipeline_under_test:
pipeline_constructor(pipeline_under_test)

with open(output_file_pattern + '-00000-of-00001') as output_txt_file:
output_lines = output_txt_file.readlines()
self.assertLen(output_lines, 1)

line_parts = output_lines[0].split(',')
self.assertEqual(line_parts[0], '3b31')
self.assertLess(float(line_parts[1]), 10)
throughputs = []
for block_with_throughput in tfrecord.read_protos(
[output_file_pattern + '-00000-of-00001'],
throughput_pb2.BasicBlockWithThroughputProto,
):
throughputs.append(
block_with_throughput.inverse_throughputs[
0
].inverse_throughput_cycles[0]
)

self.assertLen(throughputs, 1)
self.assertLess(throughputs[0], 10)


if __name__ == '__main__':
Expand Down

0 comments on commit 81ef300

Please sign in to comment.