Skip to content

Commit

Permalink
Refactor agent and node workload types
Browse files Browse the repository at this point in the history
  • Loading branch information
kthomas committed Jan 5, 2024
1 parent 447bf74 commit ab19f49
Show file tree
Hide file tree
Showing 14 changed files with 205 additions and 97 deletions.
55 changes: 40 additions & 15 deletions agent-api/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agentapi

import (
"errors"
"io"
"time"
)
Expand All @@ -27,38 +28,62 @@ type ExecutionProviderParams struct {
Stderr io.Writer `json:"-"`
Stdout io.Writer `json:"-"`

TmpFilename string `json:"-"`
VmID string `json:"-"`
TmpFilename *string `json:"-"`
VmID string `json:"-"`
}

type WorkRequest struct {
WorkloadName string `json:"workload_name"`
Hash string `json:"hash"`
TotalBytes int32 `json:"total_bytes"`
Environment map[string]string `json:"environment"`
WorkloadType string `json:"workload_type,omitempty"`
Hash *string `json:"hash,omitempty"`
TotalBytes *int32 `json:"total_bytes,omitempty"`
WorkloadName *string `json:"workload_name,omitempty"`
WorkloadType *string `json:"workload_type,omitempty"`

Stderr io.Writer `json:"-"`
Stdout io.Writer `json:"-"`
TmpFilename string `json:"-"`
TmpFilename *string `json:"-"`

Errors []error `json:"errors,omitempty"`
}

func (w *WorkRequest) Validate() bool {
w.Errors = make([]error, 0)

if w.WorkloadName == nil {
w.Errors = append(w.Errors, errors.New("workload name is required"))
}

if w.Hash == nil {
w.Errors = append(w.Errors, errors.New("hash is required"))
}

if w.TotalBytes == nil {
w.Errors = append(w.Errors, errors.New("total bytes is required"))
}

if w.WorkloadType == nil {
w.Errors = append(w.Errors, errors.New("workload type is required"))
}

return len(w.Errors) == 0
}

type WorkResponse struct {
Accepted bool `json:"accepted"`
Message string `json:"message"`
Accepted bool `json:"accepted"`
Message *string `json:"message"`
}

type HandshakeRequest struct {
MachineId string `json:"machine_id"`
MachineId *string `json:"machine_id"`
StartTime time.Time `json:"start_time"`
Message string `json:"message,omitempty"`
Message *string `json:"message,omitempty"`
}

type MachineMetadata struct {
VmId string `json:"vmid"`
NodeNatsAddress string `json:"node_address"`
NodePort int `json:"node_port"`
Message string `json:"message"`
VmId *string `json:"vmid"`
NodeNatsAddress *string `json:"node_address"`
NodePort *int `json:"node_port"`
Message *string `json:"message"`
}

type LogEntry struct {
Expand Down
9 changes: 9 additions & 0 deletions agent-api/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package agentapi

// returns the given string or nil if empty
func StringOrNil(str string) *string {
if str == "" {
return nil
}
return &str
}
39 changes: 26 additions & 13 deletions nex-agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewAgent() (*Agent, error) {
return nil, err
}

nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", metadata.NodeNatsAddress, metadata.NodePort))
nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", *metadata.NodeNatsAddress, metadata.NodePort))
if err != nil {
fmt.Fprintf(os.Stderr, "failed to connect to shared NATS: %s", err)
return nil, err
Expand Down Expand Up @@ -77,7 +77,7 @@ func (a *Agent) Start() error {
return err
}

subject := fmt.Sprintf("agentint.%s.workdispatch", a.md.VmId)
subject := fmt.Sprintf("agentint.%s.workdispatch", *a.md.VmId)
_, err = a.nc.Subscribe(subject, a.handleWorkDispatched)
if err != nil {
a.LogError(fmt.Sprintf("Failed to subscribe to work dispatch: %s", err))
Expand Down Expand Up @@ -124,6 +124,11 @@ func (a *Agent) handleWorkDispatched(m *nats.Msg) {
return
}

if !request.Validate() {
_ = a.workAck(m, false, fmt.Sprintf("%v", request.Errors)) // FIXME-- this message can be formatted prettier
return
}

tmpFile, err := a.cacheExecutableArtifact(&request)
if err != nil {
_ = a.workAck(m, false, err.Error())
Expand Down Expand Up @@ -167,7 +172,7 @@ func (a *Agent) handleWorkDispatched(m *nats.Msg) {
func (a *Agent) cacheExecutableArtifact(req *agentapi.WorkRequest) (*string, error) {
tempFile := path.Join(os.TempDir(), "workload") // FIXME-- randomly generate a filename

err := a.cacheBucket.GetFile(req.WorkloadName, tempFile)
err := a.cacheBucket.GetFile(*req.WorkloadName, tempFile)
if err != nil {
msg := fmt.Sprintf("Failed to write workload artifact to temp dir: %s", err)
a.LogError(msg)
Expand All @@ -187,12 +192,20 @@ func (a *Agent) cacheExecutableArtifact(req *agentapi.WorkRequest) (*string, err
// newExecutionProviderParams initializes new execution provider params
// for the given work request and starts a goroutine listening
func (a *Agent) newExecutionProviderParams(req *agentapi.WorkRequest, tmpFile string) (*agentapi.ExecutionProviderParams, error) {
if a.md.VmId == nil {
return nil, errors.New("vm id is required to initialize execution provider params")
}

if req.WorkloadName == nil {
return nil, errors.New("workload name is required to initialize execution provider params")
}

params := &agentapi.ExecutionProviderParams{
WorkRequest: *req,
Stderr: &logEmitter{stderr: true, name: req.WorkloadName, logs: a.agentLogs},
Stdout: &logEmitter{stderr: false, name: req.WorkloadName, logs: a.agentLogs},
TmpFilename: tmpFile,
VmID: a.md.VmId,
Stderr: &logEmitter{stderr: true, name: *req.WorkloadName, logs: a.agentLogs},
Stdout: &logEmitter{stderr: false, name: *req.WorkloadName, logs: a.agentLogs},
TmpFilename: &tmpFile,
VmID: *a.md.VmId,

Fail: make(chan bool),
Run: make(chan bool),
Expand All @@ -205,17 +218,17 @@ func (a *Agent) newExecutionProviderParams(req *agentapi.WorkRequest, tmpFile st
for {
select {
case <-params.Fail:
msg := fmt.Sprintf("Failed to start workload: %s; vm: %s", params.WorkloadName, params.VmID)
a.PublishWorkloadExited(params.VmID, params.WorkloadName, msg, true, -1)
msg := fmt.Sprintf("Failed to start workload: %s; vm: %s", *params.WorkloadName, params.VmID)
a.PublishWorkloadExited(params.VmID, *params.WorkloadName, msg, true, -1)
return

case <-params.Run:
a.PublishWorkloadStarted(params.VmID, params.WorkloadName, params.TotalBytes)
a.PublishWorkloadStarted(params.VmID, *params.WorkloadName, params.TotalBytes)
sleepMillis = workloadExecutionSleepTimeoutMillis

case exit := <-params.Exit:
msg := fmt.Sprintf("Exited workload: %s; vm: %s; status: %d", params.WorkloadName, params.VmID, exit)
a.PublishWorkloadExited(params.VmID, params.WorkloadName, msg, exit != 0, exit)
msg := fmt.Sprintf("Exited workload: %s; vm: %s; status: %d", *params.WorkloadName, params.VmID, exit)
a.PublishWorkloadExited(params.VmID, *params.WorkloadName, msg, exit != 0, exit)
return
default:
// no-op
Expand All @@ -233,7 +246,7 @@ func (a *Agent) newExecutionProviderParams(req *agentapi.WorkRequest, tmpFile st
func (a *Agent) workAck(m *nats.Msg, accepted bool, msg string) error {
ack := agentapi.WorkResponse{
Accepted: accepted,
Message: msg,
Message: agentapi.StringOrNil(msg),
}

bytes, err := json.Marshal(&ack)
Expand Down
2 changes: 1 addition & 1 deletion nex-agent/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (a *Agent) submitLog(msg string, lvl agentapi.LogLevel) {
}

// FIXME-- revisit error handling
func (a *Agent) PublishWorkloadStarted(vmID, workloadName string, totalBytes int32) {
func (a *Agent) PublishWorkloadStarted(vmID, workloadName string, totalBytes *int32) {
select {
case a.agentLogs <- &agentapi.LogEntry{
Source: NexEventSourceNexAgent,
Expand Down
4 changes: 2 additions & 2 deletions nex-agent/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (a *Agent) dispatchEvents() {
continue
}

subject := fmt.Sprintf("agentint.%s.events.%s", a.md.VmId, entry.Type())
subject := fmt.Sprintf("agentint.%s.events.%s", *a.md.VmId, entry.Type())
err = a.nc.Publish(subject, bytes)
if err != nil {
continue
Expand All @@ -64,7 +64,7 @@ func (a *Agent) dispatchLogs() {
continue
}

subject := fmt.Sprintf("agentint.%s.logs", a.md.VmId)
subject := fmt.Sprintf("agentint.%s.logs", *a.md.VmId)
err = a.nc.Publish(subject, bytes)
if err != nil {
continue
Expand Down
8 changes: 4 additions & 4 deletions nex-agent/providers/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ type ExecutionProvider interface {

// NewExecutionProvider initializes and returns an execution provider for a given work request
func NewExecutionProvider(params *agentapi.ExecutionProviderParams) (ExecutionProvider, error) {
if params.WorkloadType == "" { // FIXME-- should req.WorkloadType be a *string for better readability? e.g., json.Unmarshal will set req.Type == "" even if it is not provided.
if params.WorkloadType == nil {
return nil, errors.New("execution provider factory requires a workload type parameter")
}

switch params.WorkloadType {
switch *params.WorkloadType {
case NexExecutionProviderELF:
return lib.InitNexExecutionProviderELF(params), nil
return lib.InitNexExecutionProviderELF(params)
case NexExecutionProviderV8:
return lib.InitNexExecutionProviderV8(params), nil
return lib.InitNexExecutionProviderV8(params)
case NexExecutionProviderOCI:
// TODO-- return lib.InitNexExecutionProviderOCI(params), nil
return nil, errors.New("oci execution provider not yet implemented")
Expand Down
22 changes: 17 additions & 5 deletions nex-agent/providers/lib/elf.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,24 @@ func (e *ELF) Validate() error {
}

// convenience method to initialize an ELF execution provider
func InitNexExecutionProviderELF(params *agentapi.ExecutionProviderParams) *ELF {
func InitNexExecutionProviderELF(params *agentapi.ExecutionProviderParams) (*ELF, error) {
if params.WorkloadName == nil {
return nil, errors.New("ELF execution provider requires a workload name parameter")
}

if params.TmpFilename == nil {
return nil, errors.New("ELF execution provider requires a temporary filename parameter")
}

if params.TotalBytes == nil {
return nil, errors.New("ELF execution provider requires a VM id parameter")
}

return &ELF{
environment: params.Environment,
name: params.WorkloadName,
tmpFilename: params.TmpFilename,
totalBytes: params.TotalBytes,
name: *params.WorkloadName,
tmpFilename: *params.TmpFilename,
totalBytes: *params.TotalBytes,
vmID: params.VmID,

stderr: params.Stderr,
Expand All @@ -94,7 +106,7 @@ func InitNexExecutionProviderELF(params *agentapi.ExecutionProviderParams) *ELF
fail: params.Fail,
run: params.Run,
exit: params.Exit,
}
}, nil
}

// Validates that the indicated file is a 64-bit linux native elf binary that is statically linked.
Expand Down
23 changes: 18 additions & 5 deletions nex-agent/providers/lib/v8.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lib

import (
"errors"
"fmt"
"os"
"time"
Expand Down Expand Up @@ -109,12 +110,24 @@ func (v *V8) Validate() error {
}

// convenience method to initialize a V8 execution provider
func InitNexExecutionProviderV8(params *agentapi.ExecutionProviderParams) *V8 {
func InitNexExecutionProviderV8(params *agentapi.ExecutionProviderParams) (*V8, error) {
if params.WorkloadName == nil {
return nil, errors.New("V8 execution provider requires a workload name parameter")
}

if params.TmpFilename == nil {
return nil, errors.New("V8 execution provider requires a temporary filename parameter")
}

if params.TotalBytes == nil {
return nil, errors.New("V8 execution provider requires a VM id parameter")
}

return &V8{
environment: params.Environment,
name: params.WorkloadName,
tmpFilename: params.TmpFilename,
totalBytes: params.TotalBytes,
name: *params.WorkloadName,
tmpFilename: *params.TmpFilename,
totalBytes: *params.TotalBytes,
vmID: params.VmID,

// stderr: params.Stderr,
Expand All @@ -125,5 +138,5 @@ func InitNexExecutionProviderV8(params *agentapi.ExecutionProviderParams) *V8 {
exit: params.Exit,

ctx: v8.NewContext(),
}
}, nil
}
2 changes: 1 addition & 1 deletion nex-node/agentcomms.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func handleHandshake(mgr *MachineManager) func(m *nats.Msg) {
}

now := time.Now().UTC()
mgr.handshakes[shake.MachineId] = now.Format(time.RFC3339)
mgr.handshakes[*shake.MachineId] = now.Format(time.RFC3339)

mgr.log.WithField("vmid", shake.MachineId).WithField("message", shake.Message).Info("Received agent handshake")
err = m.Respond([]byte("OK"))
Expand Down
8 changes: 6 additions & 2 deletions nex-node/cmd/nex-node/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ func main() {
}

func cmdUp(opts *nexnode.CliOptions, ctx context.Context, cancel context.CancelFunc, log *logrus.Logger) {

nc, err := generateConnectionFromOpts(opts)
if err != nil {
log.WithError(err).Error("Failed to connect to NATS")
Expand All @@ -99,7 +98,12 @@ func cmdUp(opts *nexnode.CliOptions, ctx context.Context, cancel context.CancelF

log.Infof("Loaded node configuration from '%s'", opts.NodeConfigFile)

manager := nexnode.NewMachineManager(ctx, cancel, nc, config, log)
manager, err := nexnode.NewMachineManager(ctx, cancel, nc, config, log)
if err != nil {
log.WithError(err).Error("Failed to initialize machine manager")
os.Exit(1)
}

err = manager.Start()
if err != nil {
log.WithError(err).Error("Failed to start machine manager")
Expand Down
Loading

0 comments on commit ab19f49

Please sign in to comment.