Change RunId to be a string inside the scheduler #3890

Merged · 10 commits · Sep 4, 2024
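The diff below migrates run ids inside the scheduler from uuid.UUID to plain string, converting to and from the armadaevents.Uuid proto type only at the executor API boundary. As a rough illustration of that boundary (a minimal sketch using github.com/google/uuid, not Armada code; mustParseRunId is a hypothetical stand-in for what a Must*-style converter such as armadaevents.MustProtoUuidFromUuidString is assumed to do on malformed input):

package main

import (
	"fmt"

	"github.com/google/uuid"
)

// mustParseRunId panics on malformed input instead of returning an error,
// mirroring the assumed Must* conversion behaviour at the API edge.
func mustParseRunId(s string) uuid.UUID {
	id, err := uuid.Parse(s)
	if err != nil {
		panic(fmt.Sprintf("invalid run id %q: %v", s, err))
	}
	return id
}

func main() {
	runId := uuid.NewString()          // created and passed around as a string
	fmt.Println(mustParseRunId(runId)) // converted back only at the boundary
}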
19 changes: 7 additions & 12 deletions internal/scheduler/api.go
@@ -7,7 +7,6 @@ import (
 
 	"github.com/gogo/protobuf/proto"
 	"github.com/gogo/protobuf/types"
-	"github.com/google/uuid"
 	"github.com/pkg/errors"
 	log "github.com/sirupsen/logrus"
 	v1 "k8s.io/api/core/v1"
@@ -114,8 +113,8 @@ func (srv *ExecutorApi) LeaseJobRuns(stream executorapi.ExecutorApi_LeaseJobRuns
 	if err := stream.Send(&executorapi.LeaseStreamMessage{
 		Event: &executorapi.LeaseStreamMessage_CancelRuns{
 			CancelRuns: &executorapi.CancelRuns{
-				JobRunIdsToCancel: slices.Map(runsToCancel, func(x uuid.UUID) *armadaevents.Uuid {
-					return armadaevents.ProtoUuidFromUuid(x)
+				JobRunIdsToCancel: slices.Map(runsToCancel, func(x string) *armadaevents.Uuid {
+					return armadaevents.MustProtoUuidFromUuidString(x)
 				}),
 			},
 		},
@@ -364,19 +363,15 @@ func (srv *ExecutorApi) executorFromLeaseRequest(ctx *armadacontext.Context, req
 }
 
 // runIdsFromLeaseRequest returns the ids of all runs in a lease request, including any not yet assigned to a node.
-func runIdsFromLeaseRequest(req *executorapi.LeaseRequest) ([]uuid.UUID, error) {
-	runIds := make([]uuid.UUID, 0, 256)
+func runIdsFromLeaseRequest(req *executorapi.LeaseRequest) ([]string, error) {
+	runIds := make([]string, 0, 256)
 	for _, node := range req.Nodes {
-		for runIdStr := range node.RunIdsByState {
-			if runId, err := uuid.Parse(runIdStr); err != nil {
-				return nil, errors.WithStack(err)
-			} else {
-				runIds = append(runIds, runId)
-			}
+		for runId := range node.RunIdsByState {
+			runIds = append(runIds, runId)
 		}
 	}
 	for _, runId := range req.UnassignedJobRunIds {
-		runIds = append(runIds, armadaevents.UuidFromProtoUuid(runId))
+		runIds = append(runIds, armadaevents.UuidFromProtoUuid(runId).String())
 	}
 	return runIds, nil
 }
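One behavioural consequence of the runIdsFromLeaseRequest change above: the per-id uuid.Parse validation is gone, so a malformed id in RunIdsByState now flows through this function instead of failing it (the error return stays in the signature, but this path can no longer produce one). A minimal sketch of the difference, using hypothetical oldCollect/newCollect helpers that only mirror the removed and added loops:

package main

import (
	"fmt"

	"github.com/google/uuid"
)

// oldCollect mirrors the removed code: every id must parse as a UUID.
func oldCollect(ids map[string]struct{}) ([]uuid.UUID, error) {
	out := make([]uuid.UUID, 0, len(ids))
	for s := range ids {
		id, err := uuid.Parse(s)
		if err != nil {
			return nil, err
		}
		out = append(out, id)
	}
	return out, nil
}

// newCollect mirrors the added code: ids are taken at face value.
func newCollect(ids map[string]struct{}) []string {
	out := make([]string, 0, len(ids))
	for s := range ids {
		out = append(out, s)
	}
	return out
}

func main() {
	ids := map[string]struct{}{"not-a-uuid": {}}
	if _, err := oldCollect(ids); err != nil {
		fmt.Println("old path rejects:", err) // old code surfaced a parse error here
	}
	fmt.Println("new path accepts:", newCollect(ids)) // new code passes the string through
}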
24 changes: 12 additions & 12 deletions internal/scheduler/api_test.go
@@ -43,9 +43,9 @@ var priorityClasses = map[string]types.PriorityClass{
 func TestExecutorApi_LeaseJobRuns(t *testing.T) {
 	const maxJobsPerCall = uint(100)
 	testClock := clock.NewFakeClock(time.Now())
-	runId1 := uuid.New()
-	runId2 := uuid.New()
-	runId3 := uuid.New()
+	runId1 := uuid.NewString()
+	runId2 := uuid.NewString()
+	runId3 := uuid.NewString()
 	groups, compressedGroups := groups(t)
 	defaultRequest := &executorapi.LeaseRequest{
 		ExecutorId: "test-executor",
@@ -54,13 +54,13 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) {
 			{
 				Name: "test-node",
 				RunIdsByState: map[string]api.JobState{
-					runId1.String(): api.JobState_RUNNING,
-					runId2.String(): api.JobState_RUNNING,
+					runId1: api.JobState_RUNNING,
+					runId2: api.JobState_RUNNING,
 				},
 				NodeType: "node-type-1",
 			},
 		},
-		UnassignedJobRunIds: []*armadaevents.Uuid{armadaevents.ProtoUuidFromUuid(runId3)},
+		UnassignedJobRunIds: []*armadaevents.Uuid{armadaevents.MustProtoUuidFromUuidString(runId3)},
 		MaxJobsToLease: uint32(maxJobsPerCall),
 	}
 	defaultExpectedExecutor := &schedulerobjects.Executor{
@@ -72,7 +72,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) {
 		Name:           "test-node",
 		Executor:       "test-executor",
 		TotalResources: schedulerobjects.NewResourceList(0),
-		StateByJobRunId: map[string]schedulerobjects.JobRunState{runId1.String(): schedulerobjects.JobRunState_RUNNING, runId2.String(): schedulerobjects.JobRunState_RUNNING},
+		StateByJobRunId: map[string]schedulerobjects.JobRunState{runId1: schedulerobjects.JobRunState_RUNNING, runId2: schedulerobjects.JobRunState_RUNNING},
 		NonArmadaAllocatedResources: map[int32]schedulerobjects.ResourceList{},
 		AllocatableByPriorityAndResource: map[int32]schedulerobjects.ResourceList{
 			1000: {
@@ -88,7 +88,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) {
 			},
 		},
 		LastUpdateTime:    testClock.Now().UTC(),
-		UnassignedJobRuns: []string{runId3.String()},
+		UnassignedJobRuns: []string{runId3},
 	}
 
 	submit, compressedSubmit := submitMsg(
@@ -184,20 +184,20 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) {
 
 	tests := map[string]struct {
 		request          *executorapi.LeaseRequest
-		runsToCancel     []uuid.UUID
+		runsToCancel     []string
 		leases           []*database.JobRunLease
 		expectedExecutor *schedulerobjects.Executor
 		expectedMsgs     []*executorapi.LeaseStreamMessage
 	}{
 		"lease and cancel": {
 			request:          defaultRequest,
-			runsToCancel:     []uuid.UUID{runId2},
+			runsToCancel:     []string{runId2},
 			leases:           []*database.JobRunLease{defaultLease},
 			expectedExecutor: defaultExpectedExecutor,
 			expectedMsgs: []*executorapi.LeaseStreamMessage{
 				{
 					Event: &executorapi.LeaseStreamMessage_CancelRuns{CancelRuns: &executorapi.CancelRuns{
-						JobRunIdsToCancel: []*armadaevents.Uuid{armadaevents.ProtoUuidFromUuid(runId2)},
+						JobRunIdsToCancel: []*armadaevents.Uuid{armadaevents.MustProtoUuidFromUuidString(runId2)},
 					}},
 				},
 				{
@@ -304,7 +304,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) {
 				assert.Equal(t, tc.expectedExecutor, executor)
 				return nil
 			}).Times(1)
-			mockJobRepository.EXPECT().FindInactiveRuns(gomock.Any(), schedulermocks.SliceMatcher[uuid.UUID]{Expected: runIds}).Return(tc.runsToCancel, nil).Times(1)
+			mockJobRepository.EXPECT().FindInactiveRuns(gomock.Any(), schedulermocks.SliceMatcher{Expected: runIds}).Return(tc.runsToCancel, nil).Times(1)
 			mockJobRepository.EXPECT().FetchJobRunLeases(gomock.Any(), tc.request.ExecutorId, maxJobsPerCall, runIds).Return(tc.leases, nil).Times(1)
 
 			// capture all sent messages
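The test-side change is mechanical: uuid.NewString() is documented in github.com/google/uuid as equivalent to uuid.New().String(), so the test ids can be generated directly as strings. For example:

package main

import (
	"fmt"

	"github.com/google/uuid"
)

func main() {
	a := uuid.New().String() // old style: generate a UUID value, then stringify it
	b := uuid.NewString()    // new style: one call, same 36-character canonical form
	fmt.Println(len(a), len(b)) // 36 36
}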
24 changes: 12 additions & 12 deletions internal/scheduler/database/job_repository.go
@@ -41,7 +41,7 @@ type JobRepository interface {
 
 	// FetchJobRunErrors returns all armadaevents.JobRunErrors for the provided job run ids. The returned map is
 	// keyed by job run id. Any dbRuns which don't have errors will be absent from the map.
-	FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error)
+	FetchJobRunErrors(ctx *armadacontext.Context, runIds []string) (map[string]*armadaevents.Error, error)
 
 	// CountReceivedPartitions returns a count of the number of partition messages present in the database corresponding
 	// to the provided groupId. This is used by the scheduler to determine if the database represents the state of
@@ -50,11 +50,11 @@ type JobRepository interface {
 
 	// FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active
 	// Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled
-	FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error)
+	FindInactiveRuns(ctx *armadacontext.Context, runIds []string) ([]string, error)
 
 	// FetchJobRunLeases fetches new job runs for a given executor. A maximum of maxResults rows will be returned, while runs
 	// in excludedRunIds will be excluded
-	FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error)
+	FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []string) ([]*JobRunLease, error)
 }
 
 // PostgresJobRepository is an implementation of JobRepository that stores its state in postgres
@@ -74,14 +74,14 @@ func NewPostgresJobRepository(db *pgxpool.Pool, batchSize int32) *PostgresJobRep
 
 // FetchJobRunErrors returns all armadaevents.JobRunErrors for the provided job run ids. The returned map is
 // keyed by job run id. Any dbRuns which don't have errors will be absent from the map.
-func (r *PostgresJobRepository) FetchJobRunErrors(ctx *armadacontext.Context, runIds []uuid.UUID) (map[uuid.UUID]*armadaevents.Error, error) {
+func (r *PostgresJobRepository) FetchJobRunErrors(ctx *armadacontext.Context, runIds []string) (map[string]*armadaevents.Error, error) {
 	if len(runIds) == 0 {
-		return map[uuid.UUID]*armadaevents.Error{}, nil
+		return map[string]*armadaevents.Error{}, nil
 	}
 
 	chunks := armadaslices.PartitionToMaxLen(runIds, int(r.batchSize))
 
-	errorsByRunId := make(map[uuid.UUID]*armadaevents.Error, len(runIds))
+	errorsByRunId := make(map[string]*armadaevents.Error, len(runIds))
 	decompressor := compress.NewZlibDecompressor()
 
 	err := pgx.BeginTxFunc(ctx, r.db, pgx.TxOptions{
@@ -116,7 +116,7 @@ func (r *PostgresJobRepository) FetchJobRunErrors(ctx *armadacontext.Context, ru
 			if err != nil {
 				return errors.WithStack(err)
 			}
-			errorsByRunId[runId] = jobError
+			errorsByRunId[runId.String()] = jobError
 		}
 	}
 	return nil
@@ -192,8 +192,8 @@ func (r *PostgresJobRepository) FetchJobUpdates(ctx *armadacontext.Context, jobS
 
 // FindInactiveRuns returns a slice containing all dbRuns that the scheduler does not currently consider active
 // Runs are inactive if they don't exist or if they have succeeded, failed or been cancelled
-func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, runIds []uuid.UUID) ([]uuid.UUID, error) {
-	var inactiveRuns []uuid.UUID
+func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, runIds []string) ([]string, error) {
+	var inactiveRuns []string
 	err := pgx.BeginTxFunc(ctx, r.db, pgx.TxOptions{
 		IsoLevel:   pgx.ReadCommitted,
 		AccessMode: pgx.ReadWrite,
@@ -224,7 +224,7 @@ func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, run
 		if err != nil {
 			return errors.WithStack(err)
 		}
-		inactiveRuns = append(inactiveRuns, runId)
+		inactiveRuns = append(inactiveRuns, runId.String())
 	}
 	return nil
 })
@@ -233,7 +233,7 @@ func (r *PostgresJobRepository) FindInactiveRuns(ctx *armadacontext.Context, run
 
 // FetchJobRunLeases fetches new job runs for a given executor. A maximum of maxResults rows will be returned, while runs
 // in excludedRunIds will be excluded
-func (r *PostgresJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []uuid.UUID) ([]*JobRunLease, error) {
+func (r *PostgresJobRepository) FetchJobRunLeases(ctx *armadacontext.Context, executor string, maxResults uint, excludedRunIds []string) ([]*JobRunLease, error) {
 	if maxResults == 0 {
 		return []*JobRunLease{}, nil
 	}
@@ -312,7 +312,7 @@ func fetch[T hasSerial](from int64, batchSize int32, fetchBatch func(int64) ([]T
 }
 
 // Insert all run ids into a tmp table. The name of the table is returned
-func insertRunIdsToTmpTable(ctx *armadacontext.Context, tx pgx.Tx, runIds []uuid.UUID) (string, error) {
+func insertRunIdsToTmpTable(ctx *armadacontext.Context, tx pgx.Tx, runIds []string) (string, error) {
 	tmpTable := database.UniqueTableName("job_runs")
 
 	_, err := tx.Exec(ctx, fmt.Sprintf("CREATE TEMPORARY TABLE %s (run_id uuid) ON COMMIT DROP", tmpTable))
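Worth noting: the temp table above keeps its run_id uuid column even though the Go side now passes strings. Postgres casts valid UUID text to uuid on insert, so a malformed id would surface as a database error rather than a Go-side uuid.Parse error. A minimal sketch of that behaviour (assumes a reachable Postgres at a hypothetical DSN and the pgx v5 API, which may differ from the pgx version used in this repo):

package main

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost:5432/postgres") // hypothetical DSN
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	if _, err = conn.Exec(ctx, "CREATE TEMPORARY TABLE job_runs_tmp (run_id uuid)"); err != nil {
		panic(err)
	}

	// A valid UUID string is cast to uuid by Postgres on insert.
	_, err = conn.Exec(ctx, "INSERT INTO job_runs_tmp (run_id) VALUES ($1)",
		"123e4567-e89b-12d3-a456-426614174000")
	fmt.Println("valid id insert error:", err) // <nil>

	// A malformed string is rejected by the database, not by Go code.
	_, err = conn.Exec(ctx, "INSERT INTO job_runs_tmp (run_id) VALUES ($1)", "not-a-uuid")
	fmt.Println("malformed id insert error:", err) // invalid input syntax for type uuid
}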