Skip to content

Commit

Permalink
reapi: favor platform set in Action over Command (#7661)
Browse files Browse the repository at this point in the history
In REAPI v2.2, the platform field in Command was deprecated in favor of
the one in Action. Although most older REAPI clients, such as Bazel, set
it in both messages, newer clients such as Buck2 only set it in the
Action message.

This change replaces most of the existing command.GetPlatform() calls with
action.GetPlatform(), using the command value as a fallback.

This was heavily inspired by
#7550, but with a few
more changes.

I tried a simpler approach where we overwrite command.platform with
action.platform, but because we read the original protos by digest in
ExecutionServer.markTaskComplete, this doesn't work.

Closes buildbuddy-io/buildbuddy-internal#3863
  • Loading branch information
vanja-p authored Oct 7, 2024
1 parent 5cc4539 commit 4394142
Show file tree
Hide file tree
Showing 16 changed files with 106 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ func NewContainer(ctx context.Context, env environment.Env, task *repb.Execution
c.vmIdx = opts.ForceVMIdx
}

c.supportsRemoteSnapshots = *snaputil.EnableRemoteSnapshotSharing && (platform.IsCICommand(task.GetCommand()) || *forceRemoteSnapshotting)
c.supportsRemoteSnapshots = *snaputil.EnableRemoteSnapshotSharing && (platform.IsCICommand(task.GetCommand(), platform.GetProto(task.GetAction(), task.GetCommand())) || *forceRemoteSnapshotting)

if opts.OverrideSnapshotKey == nil {
c.vmConfig.DebugMode = *debugTerminal
Expand All @@ -640,7 +640,7 @@ func NewContainer(ctx context.Context, env environment.Env, task *repb.Execution
// If recycling is enabled and a snapshot exists, then when calling
// Create(), load the snapshot instead of creating a new VM.

recyclingEnabled := platform.IsTrue(platform.FindValue(task.GetCommand().GetPlatform(), platform.RecycleRunnerPropertyName))
recyclingEnabled := platform.IsTrue(platform.FindValue(platform.GetProto(task.GetAction(), task.GetCommand()), platform.RecycleRunnerPropertyName))
if recyclingEnabled && *snaputil.EnableLocalSnapshotSharing {
snap, err := loader.GetSnapshot(ctx, c.snapshotKeySet, c.supportsRemoteSnapshots)
c.createFromSnapshot = (err == nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,24 +513,20 @@ func (s *ExecutionServer) dispatch(ctx context.Context, req *repb.ExecuteRequest
if sizer == nil {
return "", nil, status.FailedPreconditionError("No task sizer configured")
}
adInstanceDigest := digest.NewResourceName(req.GetActionDigest(), req.GetInstanceName(), rspb.CacheType_CAS, req.GetDigestFunction())
action := &repb.Action{}
if err := cachetools.ReadProtoFromCAS(ctx, s.cache, adInstanceDigest, action); err != nil {
log.CtxWarningf(ctx, "Error fetching action: %s", err.Error())
return "", nil, err
}
cmdInstanceDigest := digest.NewResourceName(action.GetCommandDigest(), req.GetInstanceName(), rspb.CacheType_CAS, req.GetDigestFunction())
command := &repb.Command{}
if err := cachetools.ReadProtoFromCAS(ctx, s.cache, cmdInstanceDigest, command); err != nil {
log.CtxWarningf(ctx, "Error fetching command: %s", err.Error())
return "", nil, err
}

invocationID := bazel_request.GetInvocationID(ctx)
rmd := bazel_request.GetRequestMetadata(ctx)
if invocationID == "" {
log.CtxInfof(ctx, "Execution %q is missing invocation ID metadata. Request metadata: %+v", executionID, rmd)
}

adInstanceDigest := digest.NewResourceName(req.GetActionDigest(), req.GetInstanceName(), rspb.CacheType_CAS, req.GetDigestFunction())
action, command, err := s.fetchActionAndCommand(ctx, adInstanceDigest)
if err != nil {
return "", nil, err
}
if action.GetPlatform() == nil && command.GetPlatform() != nil {
log.CtxInfof(ctx, "Execution %q has a platform in the command, but not the action. Request metadata: %v", executionID, rmd)
}
// Drop ToolDetails from the request metadata. Executors will include this
// metadata in all app requests, but the ToolDetails identify the request as
// being from bazel, which is not desired.
Expand Down Expand Up @@ -588,7 +584,7 @@ func (s *ExecutionServer) dispatch(ctx context.Context, req *repb.ExecuteRequest
if err != nil {
return "", nil, err
}
envVars, err = gcplink.ExchangeRefreshTokenForAuthToken(ctx, envVars, platform.IsCICommand(command))
envVars, err = gcplink.ExchangeRefreshTokenForAuthToken(ctx, envVars, platform.IsCICommand(command, platform.GetProto(action, command)))
if err != nil {
return "", nil, err
}
Expand Down Expand Up @@ -1087,7 +1083,7 @@ func (s *ExecutionServer) markTaskComplete(ctx context.Context, taskID string, e
if err != nil {
return err
}
cmd, err := s.fetchCommandForTask(ctx, actionResourceName)
action, cmd, err := s.fetchActionAndCommand(ctx, actionResourceName)
if err != nil {
return err
}
Expand All @@ -1096,7 +1092,7 @@ func (s *ExecutionServer) markTaskComplete(ctx context.Context, taskID string, e
// Only update the router if a task was actually executed
if execErr == nil && router != nil && !executeResponse.GetCachedResult() {
executorHostID := executeResponse.GetResult().GetExecutionMetadata().GetWorker()
router.MarkComplete(ctx, cmd, actionResourceName.GetInstanceName(), executorHostID)
router.MarkComplete(ctx, action, cmd, actionResourceName.GetInstanceName(), executorHostID)
}

// Skip sizer and usage updates for teed work.
Expand Down Expand Up @@ -1160,18 +1156,20 @@ func (s *ExecutionServer) updateUsage(ctx context.Context, cmd *repb.Command, ex
return ut.Increment(ctx, labels, counts)
}

func (s *ExecutionServer) fetchCommandForTask(ctx context.Context, actionResourceName *digest.ResourceName) (*repb.Command, error) {
func (s *ExecutionServer) fetchActionAndCommand(ctx context.Context, actionResourceName *digest.ResourceName) (*repb.Action, *repb.Command, error) {
action := &repb.Action{}
if err := cachetools.ReadProtoFromCAS(ctx, s.cache, actionResourceName, action); err != nil {
return nil, err
log.CtxWarningf(ctx, "Error fetching action: %s", err.Error())
return nil, nil, err
}
cmdDigest := action.GetCommandDigest()
cmdInstanceNameDigest := digest.NewResourceName(cmdDigest, actionResourceName.GetInstanceName(), rspb.CacheType_CAS, actionResourceName.GetDigestFunction())
cmd := &repb.Command{}
if err := cachetools.ReadProtoFromCAS(ctx, s.cache, cmdInstanceNameDigest, cmd); err != nil {
return nil, err
log.CtxWarningf(ctx, "Error fetching command: %s", err.Error())
return nil, nil, err
}
return cmd, nil
return action, cmd, nil
}

func executionDuration(md *repb.ExecutedActionMetadata) (time.Duration, error) {
Expand Down
2 changes: 1 addition & 1 deletion enterprise/server/remote_execution/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func isTaskMisconfigured(err error) bool {

func isClientBazel(task *repb.ExecutionTask) bool {
// TODO(bduffany): Find a more reliable way to determine this.
return !platform.IsCICommand(task.GetCommand())
return !platform.IsCICommand(task.GetCommand(), platform.GetProto(task.GetAction(), task.GetCommand()))
}

func shouldRetry(task *repb.ExecutionTask, taskError error) bool {
Expand Down
18 changes: 14 additions & 4 deletions enterprise/server/remote_execution/platform/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (p *ExecutorProperties) SupportsIsolation(c ContainerType) bool {
// executor-specific overrides applied via the ApplyOverrides function.
func ParseProperties(task *repb.ExecutionTask) (*Properties, error) {
m := map[string]string{}
for _, prop := range task.GetCommand().GetPlatform().GetProperties() {
for _, prop := range GetProto(task.GetAction(), task.GetCommand()).GetProperties() {
m[strings.ToLower(prop.GetName())] = strings.TrimSpace(prop.GetValue())
}
for _, prop := range task.GetPlatformOverrides().GetProperties() {
Expand Down Expand Up @@ -708,6 +708,16 @@ func findValue(platform *repb.Platform, name string) (value string, ok bool) {
return "", false
}

// GetProto selects the platform proto that applies to an execution. The
// platform set on the action wins when it is non-nil; only when the action
// carries no platform is the command's platform returned instead. This
// ordering is the desired behaviour as of REAPI 2.2.
func GetProto(action *repb.Action, cmd *repb.Command) *repb.Platform {
	actionPlatform := action.GetPlatform()
	if actionPlatform == nil {
		// No platform on the action: fall back to the command's platform.
		return cmd.GetPlatform()
	}
	return actionPlatform
}

// FindValue scans the platform properties for the given property name (ignoring
// case) and returns the value of that property if it exists, otherwise "".
func FindValue(platform *repb.Platform, name string) string {
Expand All @@ -724,7 +734,7 @@ func FindEffectiveValue(task *repb.ExecutionTask, name string) string {
if ok {
return override
}
return FindValue(task.GetCommand().GetPlatform(), name)
return FindValue(GetProto(task.GetAction(), task.GetCommand()), name)
}

// IsTrue returns whether the given platform property value is truthy.
Expand All @@ -743,11 +753,11 @@ func DefaultImage() string {
// IsCICommand returns whether the given command is either a BuildBuddy workflow
// or a GitHub Actions runner task. These commands are longer-running and may
// themselves invoke bazel.
func IsCICommand(cmd *repb.Command) bool {
func IsCICommand(cmd *repb.Command, platform *repb.Platform) bool {
if len(cmd.GetArguments()) > 0 && cmd.GetArguments()[0] == "./buildbuddy_ci_runner" {
return true
}
if FindValue(cmd.GetPlatform(), "github-actions-runner-labels") != "" {
if FindValue(platform, "github-actions-runner-labels") != "" {
return true
}
return false
Expand Down
5 changes: 3 additions & 2 deletions enterprise/server/remote_execution/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ func (r *taskRunner) DownloadInputs(ctx context.Context, ioStats *repb.IOStats)
if err != nil {
return err
}
if platform.IsCICommand(r.task.GetCommand()) && !ci_runner_util.CanInitFromCache(r.PlatformProperties.OS, r.PlatformProperties.Arch) {
if platform.IsCICommand(r.task.GetCommand(), platform.GetProto(r.task.GetAction(), r.task.GetCommand())) &&
!ci_runner_util.CanInitFromCache(r.PlatformProperties.OS, r.PlatformProperties.Arch) {
if err := r.Workspace.AddCIRunner(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -943,7 +944,7 @@ func (p *pool) Get(ctx context.Context, st *repb.ScheduledTask) (interfaces.Runn
key := &rnpb.RunnerKey{
GroupId: groupID,
InstanceName: task.GetExecuteRequest().GetInstanceName(),
Platform: task.GetCommand().GetPlatform(),
Platform: platform.GetProto(task.GetAction(), task.GetCommand()),
PersistentWorkerKey: persistentWorkerKey,
}

Expand Down
1 change: 1 addition & 0 deletions enterprise/server/remote_execution/snaploader/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ go_library(
deps = [
"//enterprise/server/remote_execution/container",
"//enterprise/server/remote_execution/copy_on_write",
"//enterprise/server/remote_execution/platform",
"//enterprise/server/remote_execution/snaputil",
"//proto:firecracker_go_proto",
"//proto:remote_execution_go_proto",
Expand Down
3 changes: 2 additions & 1 deletion enterprise/server/remote_execution/snaploader/snaploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/container"
"github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/copy_on_write"
"github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/platform"
"github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/snaputil"
"github.com/buildbuddy-io/buildbuddy/server/environment"
"github.com/buildbuddy-io/buildbuddy/server/metrics"
Expand Down Expand Up @@ -48,7 +49,7 @@ const (
// SnapshotKeySet returns the cache keys for potential snapshot matches,
// as well as the key that should be written to.
func (l *FileCacheLoader) SnapshotKeySet(ctx context.Context, task *repb.ExecutionTask, configurationHash, runnerID string) (*fcpb.SnapshotKeySet, error) {
pd, err := digest.ComputeForMessage(task.GetCommand().GetPlatform(), repb.DigestFunction_SHA256)
pd, err := digest.ComputeForMessage(platform.GetProto(task.GetAction(), task.GetCommand()), repb.DigestFunction_SHA256)
if err != nil {
return nil, status.WrapErrorf(err, "failed to compute platform hash")
}
Expand Down
2 changes: 1 addition & 1 deletion enterprise/server/remote_execution/workspace/workspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (ws *Workspace) UploadOutputs(ctx context.Context, cmd *repb.Command, execu
// runner should be removed and cannot affect any files in the
// workspace anymore, so it is safe to rename the outputs files in
// upperdir here rather than copying.
recyclingEnabled := platform.IsTrue(platform.FindValue(ws.task.GetCommand().GetPlatform(), platform.RecycleRunnerPropertyName))
recyclingEnabled := platform.IsTrue(platform.FindValue(platform.GetProto(ws.task.GetAction(), ws.task.GetCommand()), platform.RecycleRunnerPropertyName))
opts := overlayfs.ApplyOpts{AllowRename: !recyclingEnabled}
if err := ws.overlay.Apply(egCtx, opts); err != nil {
return status.WrapError(err, "apply overlay upperdir changes")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1728,7 +1728,7 @@ func (s *SchedulerServer) enqueueTaskReservations(ctx context.Context, enqueueRe

attempts := 0
var rankedNodes []interfaces.RankedExecutionNode
nonPreferredDelay := getNonPreferredSchedulingDelay(cmd)
nonPreferredDelay := getNonPreferredSchedulingDelay(platform.GetProto(task.GetAction(), cmd))
delayable := enqueueRequest.GetDelay() == nil
for len(successfulReservations) < probeCount {
// If the queue of ranked, candidate nodes is empty, refresh them.
Expand All @@ -1747,7 +1747,7 @@ func (s *SchedulerServer) enqueueTaskReservations(ctx context.Context, enqueueRe
if len(candidateNodes) == 0 {
return status.UnavailableErrorf("requested executor ID not found")
}
rankedNodes = s.taskRouter.RankNodes(ctx, cmd, remoteInstanceName, toNodeInterfaces(candidateNodes))
rankedNodes = s.taskRouter.RankNodes(ctx, task.GetAction(), cmd, remoteInstanceName, toNodeInterfaces(candidateNodes))
}

select {
Expand Down Expand Up @@ -1792,8 +1792,8 @@ func (s *SchedulerServer) enqueueTaskReservations(ctx context.Context, enqueueRe

// Returns the delay that should be applied to executions scheduled on
// non-preferred execution nodes.
func getNonPreferredSchedulingDelay(cmd *repb.Command) time.Duration {
delayProperty := platform.FindValue(cmd.GetPlatform(), platform.RunnerRecyclingMaxWaitPropertyName)
func getNonPreferredSchedulingDelay(plat *repb.Platform) time.Duration {
delayProperty := platform.FindValue(plat, platform.RunnerRecyclingMaxWaitPropertyName)
if delayProperty == "" {
return defaultSchedulingDelay
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (n fakeRankedNode) IsPreferred() bool {
return n.preferred
}

func (f *fakeTaskRouter) RankNodes(ctx context.Context, cmd *repb.Command, remoteInstanceName string, nodes []interfaces.ExecutionNode) []interfaces.RankedExecutionNode {
func (f *fakeTaskRouter) RankNodes(ctx context.Context, action *repb.Action, cmd *repb.Command, remoteInstanceName string, nodes []interfaces.ExecutionNode) []interfaces.RankedExecutionNode {
rankedNodes := make([]interfaces.RankedExecutionNode, len(nodes))
for i, node := range nodes {
preferred := false
Expand All @@ -77,7 +77,7 @@ func (f *fakeTaskRouter) RankNodes(ctx context.Context, cmd *repb.Command, remot
return rankedNodes
}

func (f *fakeTaskRouter) MarkComplete(ctx context.Context, cmd *repb.Command, remoteInstanceName, executorInstanceID string) {
func (f *fakeTaskRouter) MarkComplete(ctx context.Context, action *repb.Action, cmd *repb.Command, remoteInstanceName, executorInstanceID string) {
}

type schedulerOpts struct {
Expand Down
37 changes: 17 additions & 20 deletions enterprise/server/scheduling/task_router/task_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ func nonePreferred(nodes []interfaces.ExecutionNode) []interfaces.RankedExecutio

// RankNodes returns the input nodes ordered by their affinity to the given
// routing properties.
func (tr *taskRouter) RankNodes(ctx context.Context, cmd *repb.Command, remoteInstanceName string, nodes []interfaces.ExecutionNode) []interfaces.RankedExecutionNode {
func (tr *taskRouter) RankNodes(ctx context.Context, action *repb.Action, cmd *repb.Command, remoteInstanceName string, nodes []interfaces.ExecutionNode) []interfaces.RankedExecutionNode {
nodes = copyNodes(nodes)

rand.Shuffle(len(nodes), func(i, j int) {
nodes[i], nodes[j] = nodes[j], nodes[i]
})

params := getRoutingParams(ctx, tr.env, cmd, remoteInstanceName)
params := getRoutingParams(ctx, tr.env, action, cmd, remoteInstanceName)
strategy := tr.selectRouter(params)
if strategy == nil {
return nonePreferred(nodes)
Expand Down Expand Up @@ -179,8 +179,8 @@ func (tr *taskRouter) RankNodes(ctx context.Context, cmd *repb.Command, remoteIn
// MarkComplete updates the routing table after a task is completed, so that
// future tasks with those properties are more likely to be fulfilled by the
// given node.
func (tr *taskRouter) MarkComplete(ctx context.Context, cmd *repb.Command, remoteInstanceName, executorHostID string) {
params := getRoutingParams(ctx, tr.env, cmd, remoteInstanceName)
func (tr *taskRouter) MarkComplete(ctx context.Context, action *repb.Action, cmd *repb.Command, remoteInstanceName, executorHostID string) {
params := getRoutingParams(ctx, tr.env, action, cmd, remoteInstanceName)
strategy := tr.selectRouter(params)
if strategy == nil {
return
Expand Down Expand Up @@ -221,16 +221,21 @@ func (tr *taskRouter) MarkComplete(ctx context.Context, cmd *repb.Command, remot
// Contains the parameters required to make a routing decision.
type routingParams struct {
cmd *repb.Command
platform *repb.Platform
remoteInstanceName string
groupID string
}

func getRoutingParams(ctx context.Context, env environment.Env, cmd *repb.Command, remoteInstanceName string) routingParams {
func getRoutingParams(ctx context.Context, env environment.Env, action *repb.Action, cmd *repb.Command, remoteInstanceName string) routingParams {
groupID := interfaces.AuthAnonymousUser
if u, err := env.GetAuthenticator().AuthenticatedUser(ctx); err == nil {
groupID = u.GetGroupID()
}
return routingParams{cmd: cmd, remoteInstanceName: remoteInstanceName, groupID: groupID}
return routingParams{
cmd: cmd,
platform: platform.GetProto(action, cmd),
remoteInstanceName: remoteInstanceName,
groupID: groupID}
}

// Selects and returns a Router to use, or nil if none applies.
Expand Down Expand Up @@ -273,10 +278,10 @@ type Router interface {
type ciRunnerRouter struct{}

func (ciRunnerRouter) Applies(params routingParams) bool {
return platform.IsCICommand(params.cmd) && platform.IsTrue(platform.FindValue(params.cmd.GetPlatform(), platform.RecycleRunnerPropertyName))
return platform.IsCICommand(params.cmd, params.platform) && platform.IsTrue(platform.FindValue(params.platform, platform.RecycleRunnerPropertyName))
}

func (ciRunnerRouter) preferredNodeLimit(params routingParams) int {
func (ciRunnerRouter) preferredNodeLimit(_ routingParams) int {
return ciRunnerPreferredNodeLimit
}

Expand All @@ -288,11 +293,7 @@ func (ciRunnerRouter) routingKeys(params routingParams) ([]string, error) {
parts = append(parts, params.remoteInstanceName)
}

p := params.cmd.GetPlatform()
if p == nil {
p = &repb.Platform{}
}
b, err := proto.Marshal(p)
b, err := proto.Marshal(params.platform)
if err != nil {
return nil, status.InternalErrorf("failed to marshal Command: %s", err)
}
Expand All @@ -301,7 +302,7 @@ func (ciRunnerRouter) routingKeys(params routingParams) ([]string, error) {
// For workflow tasks, route using git branch name so that when re-running the
// workflow multiple times using the same branch, the runs are more likely
// to hit an executor with a warmer snapshot cache.
if platform.IsCICommand(params.cmd) {
if platform.IsCICommand(params.cmd, params.platform) {
envVarNames := []string{"GIT_BRANCH"}
if *defaultBranchRoutingEnabled {
envVarNames = append(envVarNames, "GIT_BASE_BRANCH", "GIT_REPO_DEFAULT_BRANCH")
Expand Down Expand Up @@ -347,7 +348,7 @@ func (affinityRouter) Applies(params routingParams) bool {
return *affinityRoutingEnabled && getFirstOutput(params.cmd) != ""
}

func (affinityRouter) preferredNodeLimit(params routingParams) int {
func (affinityRouter) preferredNodeLimit(_ routingParams) int {
return defaultPreferredNodeLimit
}

Expand All @@ -358,11 +359,7 @@ func (affinityRouter) routingKey(params routingParams) (string, error) {
parts = append(parts, params.remoteInstanceName)
}

platform := params.cmd.GetPlatform()
if platform == nil {
platform = &repb.Platform{}
}
b, err := proto.Marshal(platform)
b, err := proto.Marshal(params.platform)
if err != nil {
return "", status.InternalErrorf("failed to marshal Command: %s", err)
}
Expand Down
Loading

0 comments on commit 4394142

Please sign in to comment.