Skip to content

Commit

Permalink
Add SparseTensor input validation to SparseCore conversion op.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 677054398
  • Loading branch information
mrry authored and tensorflower-gardener committed Sep 21, 2024
1 parent 3984cff commit d56766c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
31 changes: 20 additions & 11 deletions tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values,
// tensor.
} else if (indices_or_row_splits.dims() == 2 &&
indices_or_row_splits.NumElements() >= 0) {
// TODO(pineapplejuice233): Add checking logic for sparse tensor input.
// NOTE(mrry): Checking logic for SparseTensor inputs is in
// `ComputeRowIdsBeforePadding()`, to avoid an extra traversal of the
// indices matrix.
} else if (indices_or_row_splits.dims() == 1 &&
indices_or_row_splits.NumElements() > 0) {
// Ragged tensor.
Expand All @@ -114,6 +116,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values,

Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits,
const int32 total_id_count,
const int32 sample_count,
int32* row_ids_before_padding) {
// The only difference between dense tensor, sparse tensor and ragged tensor
// is the row ids output.
Expand All @@ -140,7 +143,14 @@ Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits,
if (current_row_id < previous_row_id) {
return absl::InvalidArgumentError(
"Invalid indices_or_row_splits input, indices of SparseTensor need "
"to be sorted in ascending order.");
"to be sorted in ascending (non-decreasing) order.");
}
if (current_row_id >= sample_count) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid indices_or_row_splits input, indices of SparseTensor "
"contained a row_id ",
current_row_id, " that was >= the sample count (", sample_count,
")."));
}
*(row_ids_before_padding + i) = current_row_id;
previous_row_id = current_row_id;
Expand Down Expand Up @@ -309,9 +319,9 @@ class ConvertToCooTensorOp : public OpKernel {

auto row_ids_before_dedup = std::make_unique<int32[]>(total_id_count);

OP_REQUIRES_OK(
ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count,
row_ids_before_dedup.get()));
OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding(
*indices_or_row_splits, total_id_count,
sample_count_, row_ids_before_dedup.get()));

// Compute the rescaled gains for non-sum combiners.
std::optional<std::vector<float>> gains_rescale =
Expand Down Expand Up @@ -520,9 +530,8 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) {
"The number of minibatches per sparse core is ", num_minibatch_per_sc,
". But the max minibatches per sparse core is set to be ",
max_minibatches_per_sc_, " which is smaller.")));
VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: "
<< "program_key = '" << program_key << "'"
<< ", table_name = '" << table_name_ << "'"
VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: " << "program_key = '"
<< program_key << "'" << ", table_name = '" << table_name_ << "'"
<< ", max_ids = " << max_ids_per_partition
<< ", max_uniques = " << max_unique_ids_per_partition
<< ", num_minibatch_per_sc = " << num_minibatch_per_sc;
Expand Down Expand Up @@ -1213,9 +1222,9 @@ void ConvertToListOfSparseCoreCooTensorsOp::Compute(OpKernelContext* ctx) {
auto row_ids_before_dedup = std::unique_ptr<int32[]>(
new std::remove_extent_t<int32[]>[total_id_count]);

OP_REQUIRES_OK(
ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count,
row_ids_before_dedup.get()));
OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits,
total_id_count, sample_count_,
row_ids_before_dedup.get()));

// Compute the rescaled gains for non-sum combiners.
std::optional<std::vector<float>> gains_rescale =
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values,

// Compute the row id list before padding.
Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits,
int32 total_id_count,
int32 total_id_count, int32 sample_count,
int32* row_ids_before_padding);

class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel {
Expand Down

0 comments on commit d56766c

Please sign in to comment.