diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index a572de0738..a97b8c6705 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1556,7 +1556,7 @@ def get_table_name_for_logging(table_names: Optional[list[str]]) -> str: table_name_set = set(table_names) if len(table_name_set) == 1: return next(iter(table_name_set)) - return f"<{len(table_name_set)} tables>" + return f"<{len(table_name_set)} tables>: {table_name_set}" @staticmethod def get_prefetch_passes( diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py index 3dd1bc2cd4..b6864a3ac1 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py @@ -178,17 +178,17 @@ def test_get_table_name_for_logging(self) -> None: SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2"] ), - "<2 tables>", + "<2 tables>: {'t1', 't2'}", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2", "t1"] ), - "<2 tables>", + "<2 tables>: {'t1', 't2'}", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging([]), - "<0 tables>", + "<0 tables>: set()", ) @unittest.skipIf(*gpu_unavailable)