From e9d2a976b391d0f505cd404ca115517ce9d675ef Mon Sep 17 00:00:00 2001 From: oshima Date: Wed, 20 Jun 2018 19:28:05 +0900 Subject: [PATCH] Hyperband (#124) * add hyp Signed-off-by: YujiOshima * fix hyperband suggestion Signed-off-by: YujiOshima * add test and docs * fix test Signed-off-by: YujiOshima * fix test Signed-off-by: YujiOshima --- examples/MinikubeDemo.md | 8 + examples/MinikubeDemo/deploy.sh | 1 + .../suggestion/hyperband/deployment.yaml | 30 + .../vizier/suggestion/hyperband/service.yaml | 17 + examples/hyperband-example-client.go | 320 ++++++++ examples/study-config.yml | 5 + examples/suggestion-config-hyb.yml | 11 + pkg/suggestion/hyperband_service.go | 727 ++++++++++-------- scripts/build.sh | 4 +- scripts/deploy.sh | 1 + test/e2e/study-config.yml | 6 + ...-config.yml => suggestion-config-grid.yml} | 0 test/e2e/suggestion-config-hyb.yml | 12 + test/e2e/test-client.go | 134 +++- test/scripts/build-suggestion-hyperband.sh | 41 + test/scripts/run-tests.sh | 5 +- test/workflows/components/workflows.libsonnet | 7 + 17 files changed, 987 insertions(+), 342 deletions(-) create mode 100644 examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/deployment.yaml create mode 100644 examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/service.yaml create mode 100644 examples/hyperband-example-client.go create mode 100644 examples/suggestion-config-hyb.yml rename test/e2e/{suggestion-config.yml => suggestion-config-grid.yml} (100%) create mode 100644 test/e2e/suggestion-config-hyb.yml create mode 100755 test/scripts/build-suggestion-hyperband.sh diff --git a/examples/MinikubeDemo.md b/examples/MinikubeDemo.md index 8ac1958cb12..8c1f291d30d 100644 --- a/examples/MinikubeDemo.md +++ b/examples/MinikubeDemo.md @@ -66,6 +66,14 @@ go run client-example.go -a grid ``` In this demo, make 4 grids for learning rate (--lr) Min 0.03 and Max 0.07. +### Hyperband Demo +As the Hyperband suggestion is so different from random and grid, use special client example. 
+``` +go run hyperband-example-client.go +``` +The parameters of Hyperband are defined in [suggestion-config-hyb.yml](./suggestion-config-hyb.yml). +In this demo, the eta is 3 and the R is 9. + ## UI You can check your Model with Web UI. Acsess to `http://127.0.0.1:3000/` diff --git a/examples/MinikubeDemo/deploy.sh b/examples/MinikubeDemo/deploy.sh index 201ab8a2762..61f54aee657 100755 --- a/examples/MinikubeDemo/deploy.sh +++ b/examples/MinikubeDemo/deploy.sh @@ -10,4 +10,5 @@ kubectl apply -f manifests/vizier/db kubectl apply -f manifests/vizier/core kubectl apply -f manifests/vizier/suggestion/random kubectl apply -f manifests/vizier/suggestion/grid +kubectl apply -f manifests/vizier/suggestion/hyperband kubectl apply -f manifests/vizier/earlystopping/medianstopping diff --git a/examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/deployment.yaml b/examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/deployment.yaml new file mode 100644 index 00000000000..e6ef49f0d08 --- /dev/null +++ b/examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/deployment.yaml @@ -0,0 +1,30 @@ +apiVersion: extensions/v1beta1 +kind: Deployment +metadata: + name: vizier-suggestion-hayperband + namespace: katib + labels: + app: vizier + component: suggestion-hayperband +spec: + replicas: 1 + template: + metadata: + name: vizier-suggestion-hayperband + labels: + app: vizier + component: suggestion-hayperband + spec: + containers: + - name: vizier-suggestion-hayperband + image: katib/suggestion-hayperband + ports: + - name: api + containerPort: 6789 +# resources: +# requests: +# cpu: 500m +# memory: 500M +# limits: +# cpu: 500m +# memory: 500M diff --git a/examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/service.yaml b/examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/service.yaml new file mode 100644 index 00000000000..77d0c544db3 --- /dev/null +++ b/examples/MinikubeDemo/manifests/vizier/suggestion/hyperband/service.yaml @@ -0,0 +1,17 @@
+apiVersion: v1 +kind: Service +metadata: + name: vizier-suggestion-hayperband + namespace: katib + labels: + app: vizier + component: suggestion-hayperband +spec: + type: ClusterIP + ports: + - port: 6789 + protocol: TCP + name: api + selector: + app: vizier + component: suggestion-hayperband diff --git a/examples/hyperband-example-client.go b/examples/hyperband-example-client.go new file mode 100644 index 00000000000..7302770db52 --- /dev/null +++ b/examples/hyperband-example-client.go @@ -0,0 +1,320 @@ +package main + +import ( + "context" + "flag" + "io/ioutil" + "log" + "time" + + "github.com/kubeflow/katib/pkg/api" + "google.golang.org/grpc" + "gopkg.in/yaml.v2" +) + +var managerAddr = flag.String("s", "127.0.0.1:6789", "Endpoint of manager default 127.0.0.1:6789") +var suggestArgo = flag.String("a", "hyperband", "Suggestion Algorithm (random, grid)") +var requestnum = flag.Int("r", 2, "Request number for random Suggestions (default: 2)") +var suggestionConfFile = flag.String("c", "suggestion-config-hyb.yml", "File path to suggestion config.") + +var studyConfig = api.StudyConfig{} +var workerConfig = api.WorkerConfig{} +var suggestionConfig = api.SetSuggestionParametersRequest{} + +var trials = map[string]*api.Trial{} + +func main() { + readConfigs() + conn, err := grpc.Dial(*managerAddr, grpc.WithInsecure()) + if err != nil { + log.Fatalf("could not connect: %v", err) + } + defer conn.Close() + ctx := context.Background() + c := api.NewManagerClient(conn) + + //CreateStudy + studyId := CreateStudy(c) + + //SetSuggestParam + paramId := setSuggestionParam(c, studyId) + + //Loop until end of HyperBand Algorithm + for true { + //GetSuggestion + getSuggestReply := getSuggestion(c, studyId, paramId) + checkSuggestions(getSuggestReply) + if len(getSuggestReply.Trials) == 0 { + log.Printf("Hyperband ended") + break + } + //RunTrials + workerIds := runTrials(c, studyId, getSuggestReply) + for !isCompletedAllWorker(c, studyId) { + time.Sleep(10 * time.Second) + 
getMetricsRequest := &api.GetMetricsRequest{ + StudyId: studyId, + WorkerIds: workerIds, + } + //GetMetrics + getMetricsReply, err := c.GetMetrics(ctx, getMetricsRequest) + if err != nil { + continue + } + //Save or Update model on ModelDB + SaveOrUpdateModel(c, getMetricsReply) + } + checkWorkersResult(c, studyId) + } + conn.Close() + log.Println("E2E test OK!") +} + +func readConfigs() { + flag.Parse() + buf, err := ioutil.ReadFile("study-config.yml") + if err != nil { + log.Fatalf("fail to read study-config yaml") + } + err = yaml.Unmarshal(buf, &studyConfig) + if err != nil { + log.Fatalf("fail to Unmarshal yaml") + } + + buf, err = ioutil.ReadFile("worker-config.yml") + if err != nil { + log.Fatalf("fail to read worker-config yaml") + } + err = yaml.Unmarshal(buf, &workerConfig) + if err != nil { + log.Fatalf("fail to Unmarshal yaml") + } + + if *suggestionConfFile != "" { + buf, err = ioutil.ReadFile(*suggestionConfFile) + if err != nil { + log.Fatalf("fail to read suggestion-config yaml") + } + } + err = yaml.Unmarshal(buf, &suggestionConfig) + if err != nil { + log.Fatalf("fail to Unmarshal yaml") + } +} + +func CreateStudy(c api.ManagerClient) string { + ctx := context.Background() + createStudyreq := &api.CreateStudyRequest{ + StudyConfig: &studyConfig, + } + createStudyreply, err := c.CreateStudy(ctx, createStudyreq) + if err != nil { + log.Fatalf("StudyConfig Error %v", err) + } + studyId := createStudyreply.StudyId + log.Printf("Study ID %s", studyId) + getStudyreq := &api.GetStudyRequest{ + StudyId: studyId, + } + getStudyReply, err := c.GetStudy(ctx, getStudyreq) + if err != nil { + log.Fatalf("GetConfig Error %v", err) + } + log.Printf("Study ID %s StudyConf %v", studyId, getStudyReply.StudyConfig) + return studyId +} + +func setSuggestionParam(c api.ManagerClient, studyId string) string { + ctx := context.Background() + switch *suggestArgo { + case "random": + return "" + case "grid": + suggestionConfig.StudyId = studyId + 
setSuggesitonParameterReply, err := c.SetSuggestionParameters(ctx, &suggestionConfig) + if err != nil { + log.Fatalf("SetConfig Error %v", err) + } + log.Printf("Grid suggestion prameter ID %s", setSuggesitonParameterReply.ParamId) + return setSuggesitonParameterReply.ParamId + case "hyperband": + suggestionConfig.StudyId = studyId + setSuggesitonParameterReply, err := c.SetSuggestionParameters(ctx, &suggestionConfig) + if err != nil { + log.Fatalf("SetConfig Error %v", err) + } + log.Printf("HyperBand suggestion prameter ID %s", setSuggesitonParameterReply.ParamId) + return setSuggesitonParameterReply.ParamId + } + return "" + +} + +func getSuggestion(c api.ManagerClient, studyId string, paramId string) *api.GetSuggestionsReply { + ctx := context.Background() + var getSuggestRequest *api.GetSuggestionsRequest + switch *suggestArgo { + case "random": + //Random suggestion doesn't need suggestion parameter + getSuggestRequest = &api.GetSuggestionsRequest{ + StudyId: studyId, + SuggestionAlgorithm: "random", + RequestNumber: int32(*requestnum), + } + + case "grid": + getSuggestRequest = &api.GetSuggestionsRequest{ + StudyId: studyId, + SuggestionAlgorithm: "grid", + RequestNumber: 0, + //RequestNumber=0 means get all grids. + ParamId: paramId, + } + case "hyperband": + getSuggestRequest = &api.GetSuggestionsRequest{ + StudyId: studyId, + SuggestionAlgorithm: "hyperband", + RequestNumber: 0, + ParamId: paramId, + } + } + + getSuggestReply, err := c.GetSuggestions(ctx, getSuggestRequest) + if err != nil { + log.Fatalf("GetSuggestion Error %v", err) + } + log.Println("Get " + *suggestArgo + " Suggestions:") + for _, t := range getSuggestReply.Trials { + log.Printf("%v", t) + } + return getSuggestReply +} + +func checkSuggestions(getSuggestReply *api.GetSuggestionsReply) bool { + switch *suggestArgo { + case "random": + if len(getSuggestReply.Trials) != *requestnum { + log.Fatalf("Number of Random suggestion incrrect. 
Expected %d Got %d", *requestnum, len(getSuggestReply.Trials)) + } + case "grid": + if len(getSuggestReply.Trials) != 4 { + log.Fatalf("Number of Grid suggestion incrrect. Expected %d Got %d", 4, len(getSuggestReply.Trials)) + } + } + log.Println("Check suggestion passed!") + return true +} + +func runTrials(c api.ManagerClient, studyId string, getSuggestReply *api.GetSuggestionsReply) []string { + ctx := context.Background() + workerIds := make([]string, len(getSuggestReply.Trials)) + workerParameter := make(map[string][]*api.Parameter) + for i, t := range getSuggestReply.Trials { + wc := workerConfig + rtr := &api.RunTrialRequest{ + StudyId: studyId, + TrialId: t.TrialId, + Runtime: "kubernetes", + WorkerConfig: &wc, + } + for _, p := range t.ParameterSet { + rtr.WorkerConfig.Command = append(rtr.WorkerConfig.Command, p.Name+"="+p.Value) + } + workerReply, err := c.RunTrial(ctx, rtr) + if err != nil { + log.Fatalf("RunTrial Error %v", err) + } + workerIds[i] = workerReply.WorkerId + workerParameter[workerReply.WorkerId] = t.ParameterSet + saveModelRequest := &api.SaveModelRequest{ + Model: &api.ModelInfo{ + StudyName: studyConfig.Name, + WorkerId: workerReply.WorkerId, + Parameters: t.ParameterSet, + Metrics: []*api.Metrics{}, + ModelPath: "pvc:/Path/to/Model", + }, + DataSet: &api.DataSetInfo{ + Name: "Mnist", + Path: "/path/to/data", + }, + } + _, err = c.SaveModel(ctx, saveModelRequest) + if err != nil { + log.Fatalf("SaveModel Error %v", err) + } + log.Printf("WorkerID %s start\n", workerReply.WorkerId) + trials[workerReply.WorkerId] = t + } + return workerIds +} + +func SaveOrUpdateModel(c api.ManagerClient, getMetricsReply *api.GetMetricsReply) { + ctx := context.Background() + for _, mls := range getMetricsReply.MetricsLogSets { + if len(mls.MetricsLogs) > 0 { + log.Printf("WorkerID %s :", mls.WorkerId) + //Only Metrics can be updated. 
+ saveModelRequest := &api.SaveModelRequest{ + Model: &api.ModelInfo{ + StudyName: studyConfig.Name, + WorkerId: mls.WorkerId, + Metrics: []*api.Metrics{}, + }, + } + for _, ml := range mls.MetricsLogs { + if len(ml.Values) > 0 { + log.Printf("\t Metrics Name %s Value %v", ml.Name, ml.Values[len(ml.Values)-1]) + saveModelRequest.Model.Metrics = append(saveModelRequest.Model.Metrics, &api.Metrics{Name: ml.Name, Value: ml.Values[len(ml.Values)-1]}) + } + } + _, err := c.SaveModel(ctx, saveModelRequest) + if err != nil { + log.Fatalf("SaveModel Error %v", err) + } + } + } +} + +func isCompletedAllWorker(c api.ManagerClient, studyId string) bool { + ctx := context.Background() + getWorkerRequest := &api.GetWorkersRequest{StudyId: studyId} + getWorkerReply, err := c.GetWorkers(ctx, getWorkerRequest) + if err != nil { + log.Fatalf("GetWorker Error %v", err) + } + for _, w := range getWorkerReply.Workers { + if w.Status != api.State_COMPLETED { + return false + } + } + log.Println("All Worker Completed") + return true +} + +func checkWorkersResult(c api.ManagerClient, studyId string) bool { + ctx := context.Background() + getMetricsRequest := &api.GetMetricsRequest{ + StudyId: studyId, + } + //GetMetrics + getMetricsReply, err := c.GetMetrics(ctx, getMetricsRequest) + if err != nil { + log.Fatalf("Fataled to Get Metrics") + } + + for _, mls := range getMetricsReply.MetricsLogSets { + for _, p := range trials[mls.WorkerId].ParameterSet { + for _, ml := range mls.MetricsLogs { + if p.Name == ml.Name { + if p.Value != ml.Values[len(ml.Values)-1] { + log.Fatalf("Output %s is mismuched to Input %s", ml.Values[len(ml.Values)-1], p.Value) + return false + } + } + } + } + } + log.Println("Input Output check passed") + return true +} diff --git a/examples/study-config.yml b/examples/study-config.yml index 707c93d425f..289f6d90553 100644 --- a/examples/study-config.yml +++ b/examples/study-config.yml @@ -21,6 +21,11 @@ parameterconfigs: - sgd - adam - ftrl + - name: --num-epochs + 
parametertype: 2 + feasible: + min: "20" + max: "20" objectivevaluename: Validation-accuracy metrics: - accuracy diff --git a/examples/suggestion-config-hyb.yml b/examples/suggestion-config-hyb.yml new file mode 100644 index 00000000000..02656c77eaf --- /dev/null +++ b/examples/suggestion-config-hyb.yml @@ -0,0 +1,11 @@ +suggetionalgorithm: "hyperband" +suggestionparameters: + - + name: "ResourceName" + value: "--num-epochs" + - + name: "eta" + value: "3" + - + name: "r_l" + value: "9" diff --git a/pkg/suggestion/hyperband_service.go b/pkg/suggestion/hyperband_service.go index 682cf0bbbee..efc397182db 100644 --- a/pkg/suggestion/hyperband_service.go +++ b/pkg/suggestion/hyperband_service.go @@ -2,326 +2,443 @@ package suggestion import ( "context" - // "crypto/rand" - // "fmt" - // "github.com/kubeflow/katib/pkg/db" - // "log" - // "math" - // "sort" - // "strconv" + "fmt" + "log" + "math" + "sort" + "strconv" + "strings" "github.com/kubeflow/katib/pkg/api" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) -//type Evals struct { -// id string -// value float64 -//} -//type Bracket []Evals -// -//func (b Bracket) Len() int { -// return len(b) -//} -// -//func (b Bracket) Swap(i, j int) { -// b[i], b[j] = b[j], b[i] -//} -// -//func (b Bracket) Less(i, j int) bool { -// return b[i].value > b[j].value -//} -// +type Evals struct { + id string + value float64 +} +type Bracket []Evals + +func (b Bracket) Len() int { + return len(b) +} + +func (b Bracket) Swap(i, j int) { + b[i], b[j] = b[j], b[i] +} + +func (b Bracket) Less(i, j int) bool { + return b[i].value > b[j].value +} + type HyperBandParameters struct { - eta float64 - sMax int - b_l float64 - r_l float64 - r float64 - n int - shloopitr int - currentS int - ResourceName string + eta float64 + sMax int + b_l float64 + r_l float64 + r float64 + n int + shloopitr int + currentS int + ResourceName string + ObjectiveValueName string + evaluatingTrials []string } type 
HyperBandSuggestService struct { RandomSuggestService - parameters HyperBandParameters } func NewHyperBandSuggestService() *HyperBandSuggestService { return &HyperBandSuggestService{} } -// -//func (h *HyperBandSuggestService) generate_randid() string { -// // UUID isn't quite handy in the Go world -// id_ := make([]byte, 8) -// _, err := rand.Read(id_) -// if err != nil { -// log.Fatalf("Error reading random: %v", err) -// } -// return fmt.Sprintf("%016x", id_) -//} -// -//func (h *HyperBandSuggestService) makeMasterBracket(sconf *api.StudyConfig, n int) Bracket { -// log.Printf("Make MasterBracket %v Trials", n) -// s_t := make([]*api.Trial, n) -// for i := 0; i < n; i++ { -// s_t[i] = &api.Trial{} -// s_t[i].ParameterSet = make([]*api.Parameter, len(sconf.ParameterConfigs.Configs)) -// for j, pc := range sconf.ParameterConfigs.Configs { -// s_t[i].ParameterSet[j] = &api.Parameter{Name: pc.Name} -// s_t[i].ParameterSet[j].ParameterType = pc.ParameterType -// switch pc.ParameterType { -// case api.ParameterType_INT: -// imin, _ := strconv.Atoi(pc.Feasible.Min) -// imax, _ := strconv.Atoi(pc.Feasible.Max) -// s_t[i].ParameterSet[j].Value = strconv.Itoa(h.IntRandom(imin, imax)) -// case api.ParameterType_DOUBLE: -// dmin, _ := strconv.ParseFloat(pc.Feasible.Min, 64) -// dmax, _ := strconv.ParseFloat(pc.Feasible.Max, 64) -// s_t[i].ParameterSet[j].Value = strconv.FormatFloat(h.DoubelRandom(dmin, dmax), 'f', 4, 64) -// case api.ParameterType_CATEGORICAL: -// s_t[i].ParameterSet[j].Value = pc.Feasible.List[h.IntRandom(0, len(pc.Feasible.List)-1)] -// } -// } -// s_t[i].Tags = append(s_t[i].Tags, &api.Tag{Name: "HyperBand_BracketID", Value: h.generate_randid()}) -// } -// return Bracket(s_t) -//} -// -//func (h *HyperBandSuggestService) purseSuggestionParameters(sparam []*api.SuggestionParameters) (HyperBandParameters, error) { -// p := &HyperBandParameters{ -// eta: -1, -// sMax: -1, -// b_l: -1, -// r_l: -1, -// r: -1, -// n: -1, -// shloopitr: -1, -// currentS: -1, 
-// ResourceName: -1, -// } -// for _, sp := range sparam { -// switch sp.Name { -// case "Eta": -// p.eta, _ = strconv.ParseFloat(sp.Value, 64) -// case "R": -// p.r_l, _ = strconv.ParseFloat(sp.Value, 64) -// case "ResourceName": -// p.ResourceName = sp.Value -// case "b_l": -// p.b_l, _ = strconv.ParseFloat(sp.Value, 64) -// case "sMax": -// p.sMax, _ = strconv.AtoI(sp.Value) -// case "r_s": -// p.r, _ = strconv.ParseFloat(sp.Value, 64) -// case "n_s": -// p.n, _ = strconv.AtoI(sp.Value) -// case "shloopitr": -// p.shloopitr, _ = strconv.AtoI(sp.Value) -// case "currentS": -// p.currentS, _ = strconv.AtoI(sp.Value) -// default: -// log.Printf("Unknown Suggestion Parameter %v", sp.Name) -// } -// } -// if p.eta == 0 || p.r_l == 0 || p.ResourceName == "" { -// log.Printf("Failed to Suggestion Parameter set.") -// return &api.SetSuggestionParametersReply{}, fmt.Errorf("Suggestion Parameter set Error") -// } -// if p.sMax == -1 { -// p.sMax = int(math.Log(p.r_l) / math.Log(p.eta)) -// } -// if p.b_l == -1 { -// p.b_l = float64((p.sMax + 1.0)) * p.r_l -// } -// if p.n == -1 { -// p.n = int((p.b_l/p.r_l)*(math.Pow(p.eta, float64(p.sMax))/float64(p.sMax+1))) + 1 -// } -// if p.currentS == -1 { -// p.currentS = p.sMax + 1 -// } -// if p.shloopit == -1 { -// p.shloopitr = p.currentS + 1 -// } -// if p.r == -1 { -// p.r = p.r_l * math.Pow(p.eta, float64(-p.sMax)) -// } -// p.MasterBracket = h.makeMasterBracket(in.Configs, p.n) -// h.parameters[in.StudyId] = p -// log.Printf("Smax = %v", p.sMax) -// return &api.SetSuggestionParametersReply{}, nil -//} -// -//func (h *HyperBandSuggestService) getHyperParameter(studyId string, sconf *api.StudyConfig, n int) Bracket { -// s_t := make([]*api.Trial, n) -// for i := 0; i < n; i++ { -// s_t[i] = &api.Trial{} -// s_t[i].ParameterSet = make([]*api.Parameter, len(sconf.ParameterConfigs.Configs)) -// s_t[i].Status = api.TrialState_PENDING -// s_t[i].EvalLogs = make([]*api.EvaluationLog, 0) -// var j int -// if sconf.OptimizationType 
== api.OptimizationType_MAXIMIZE { -// j = i -// } else if sconf.OptimizationType == api.OptimizationType_MINIMIZE { -// j = len(h.parameters[studyId].MasterBracket) - 1 - i -// } -// for k, v := range h.parameters[studyId].MasterBracket[j].ParameterSet { -// s_t[i].ParameterSet[k] = v -// } -// for _, t := range h.parameters[studyId].MasterBracket[j].Tags { -// s_t[i].Tags = append(s_t[i].Tags, t) -// } -// } -// return Bracket(s_t) -//} -// -//func (h *HyperBandSuggestService) hbLoopParamUpdate(studyId string) { -// log.Printf("HB loop s = %v", h.parameters[studyId].currentS) -// h.parameters[studyId].shloopitr = 0 -// h.parameters[studyId].n = int((h.parameters[studyId].b_l/h.parameters[studyId].r_l)*(math.Pow(h.parameters[studyId].eta, float64(h.parameters[studyId].currentS))/float64(h.parameters[studyId].currentS+1))) + 1 -// h.parameters[studyId].r = h.parameters[studyId].r_l * math.Pow(h.parameters[studyId].eta, float64(-h.parameters[studyId].currentS)) -//} -// -//func (h *HyperBandSuggestService) shLoopParamUpdate(studyId string) (int, int) { -// log.Printf("SH loop i = %v", h.parameters[studyId].shloopitr) -// pn_i := int(float64(h.parameters[studyId].n) * math.Pow(h.parameters[studyId].eta, float64(-h.parameters[studyId].shloopitr+1)) / h.parameters[studyId].eta) -// r_i := int(h.parameters[studyId].r * math.Pow(h.parameters[studyId].eta, float64(h.parameters[studyId].shloopitr))) -// return pn_i, r_i -//} -// -//func (h *HyperBandSuggestService) makeBracker(ctx context.Context, cl *api.ManagerClient, sid string, logs []string) (Bracket, error) { -// req := &api.GetWorkersRequest{StudyId: sid} -// r, err := api.GetWorkers(ctx, req) -// if err != nil { -// return nil, err -// } -// mreq := &api.GetMetrics{ -// StudyId: sid, -// WorkerId: logs, -// MetricsNames: []string{h.ResourceName}, -// } -// mr, err := api.GetMetrics(ctx, mreq) -// if err != nil { -// return nil, err -// } -// e_l := make([]Evals, len(logs)) -// for i, l := range logs { -// for _, m 
:= range mr.MetricsLogs { -// if m.WorkerId == l { -// if len(m.MetricsLogs) == 0 { -// break -// } -// e_l[i].Value = strconv.ParseFloat(m.MetricsLogs[len(m.MetricsLogs)-1].Value, 64) -// } -// } -// for _, w := range r.Workers { -// if w.WorkerId == l { -// e_l[i].Id = w.TrialId -// } -// } -// } -// return Bracket(e_l) -//} +func (h *HyperBandSuggestService) makeBracket(ctx context.Context, c api.ManagerClient, studyId string, n int, r float64, hbparam *HyperBandParameters) ([]string, []*api.Trial, error) { + if len(hbparam.evaluatingTrials) == 0 || hbparam.shloopitr == 0 { + return h.makeMasterBracket(ctx, c, studyId, n, r, hbparam) + } else { + err, b := h.evalWorkers(ctx, c, studyId, hbparam) + if err != nil { + return nil, nil, err + } + if b == nil { + return nil, nil, nil + } + return h.makeChildBracket(ctx, c, b, studyId, n, r, hbparam) + } +} + +func (h *HyperBandSuggestService) makeMasterBracket(ctx context.Context, c api.ManagerClient, studyId string, n int, r float64, hbparam *HyperBandParameters) ([]string, []*api.Trial, error) { + log.Printf("Make MasterBracket %v Trials", n) + gsreq := &api.GetStudyRequest{ + StudyId: studyId, + } + gsrep, err := c.GetStudy(ctx, gsreq) + if err != nil { + log.Printf("GetStudy Error") + return nil, nil, err + } + sconf := gsrep.StudyConfig + tids := make([]string, n) + ts := make([]*api.Trial, n) + for i := 0; i < n; i++ { + t := &api.Trial{ + StudyId: studyId, + } + t.ParameterSet = make([]*api.Parameter, len(sconf.ParameterConfigs.Configs)) + for j, pc := range sconf.ParameterConfigs.Configs { + t.ParameterSet[j] = &api.Parameter{Name: pc.Name} + t.ParameterSet[j].ParameterType = pc.ParameterType + if pc.Name == hbparam.ResourceName { + if pc.ParameterType == api.ParameterType_INT { + t.ParameterSet[j].Value = strconv.Itoa(int(r)) + } else { + t.ParameterSet[j].Value = strconv.FormatFloat(r, 'f', 4, 64) + } + } else { + switch pc.ParameterType { + case api.ParameterType_INT: + imin, _ := 
strconv.Atoi(pc.Feasible.Min) + imax, _ := strconv.Atoi(pc.Feasible.Max) + t.ParameterSet[j].Value = strconv.Itoa(h.IntRandom(imin, imax)) + case api.ParameterType_DOUBLE: + dmin, _ := strconv.ParseFloat(pc.Feasible.Min, 64) + dmax, _ := strconv.ParseFloat(pc.Feasible.Max, 64) + t.ParameterSet[j].Value = strconv.FormatFloat(h.DoubelRandom(dmin, dmax), 'f', 4, 64) + case api.ParameterType_CATEGORICAL: + t.ParameterSet[j].Value = pc.Feasible.List[h.IntRandom(0, len(pc.Feasible.List)-1)] + } + } + } + req := &api.CreateTrialRequest{ + Trial: t, + } + ret, err := c.CreateTrial(ctx, req) + if err != nil { + log.Printf("CreateTrial Error") + return nil, nil, err + } + tids[i] = ret.TrialId + t.TrialId = ret.TrialId + ts[i] = t + } + return tids, ts, nil +} + +func (h *HyperBandSuggestService) makeChildBracket(ctx context.Context, c api.ManagerClient, parent Bracket, studyId string, n int, r_i float64, hbparam *HyperBandParameters) ([]string, []*api.Trial, error) { + gsreq := &api.GetStudyRequest{ + StudyId: studyId, + } + gsrep, err := c.GetStudy(ctx, gsreq) + if err != nil { + log.Printf("GetStudy Error") + return nil, nil, err + } + sconf := gsrep.StudyConfig + child := Bracket{} + + if sconf.OptimizationType == api.OptimizationType_MINIMIZE { + child = parent[n:] + } else if sconf.OptimizationType == api.OptimizationType_MAXIMIZE { + child = parent[:n] + } + gtreq := &api.GetTrialsRequest{ + StudyId: studyId, + } + gtrep, err := c.GetTrials(ctx, gtreq) + if err != nil { + log.Printf("GetTrials Error") + return nil, nil, err + } + tids := make([]string, n) + ts := make([]*api.Trial, n) + var rtype api.ParameterType + for _, pc := range sconf.ParameterConfigs.Configs { + if pc.Name == hbparam.ResourceName { + rtype = pc.ParameterType + } + } + for i, tid := range child { + t := &api.Trial{ + StudyId: studyId, + } + for _, pt := range gtrep.Trials { + if pt.TrialId == tid.id { + t.ParameterSet = pt.ParameterSet + } + } + for i, p := range t.ParameterSet { + if p.Name == 
hbparam.ResourceName { + if rtype == api.ParameterType_INT { + t.ParameterSet[i].Value = strconv.Itoa(int(r_i)) + } else { + t.ParameterSet[i].Value = strconv.FormatFloat(r_i, 'f', 4, 64) + } + } + } + req := &api.CreateTrialRequest{ + Trial: t, + } + ret, err := c.CreateTrial(ctx, req) + if err != nil { + log.Printf("CreateTrial Error") + return nil, nil, err + } + tids[i] = ret.TrialId + t.TrialId = ret.TrialId + ts[i] = t + } + return tids, ts, nil +} + +func (h *HyperBandSuggestService) purseSuggestionParameters(ctx context.Context, c api.ManagerClient, studyId string, sparam []*api.SuggestionParameter) (*HyperBandParameters, error) { + p := &HyperBandParameters{ + eta: -1, + sMax: -1, + b_l: -1, + r_l: -1, + r: -1, + n: -1, + shloopitr: -1, + currentS: -1, + ResourceName: "", + ObjectiveValueName: "", + evaluatingTrials: []string{}, + } + for _, sp := range sparam { + switch sp.Name { + case "eta": + p.eta, _ = strconv.ParseFloat(sp.Value, 64) + case "r_l": + p.r_l, _ = strconv.ParseFloat(sp.Value, 64) + case "ResourceName": + p.ResourceName = sp.Value + case "ObjectiveValueName": + p.ObjectiveValueName = sp.Value + case "b_l": + p.b_l, _ = strconv.ParseFloat(sp.Value, 64) + case "sMax": + p.sMax, _ = strconv.Atoi(sp.Value) + case "r": + p.r, _ = strconv.ParseFloat(sp.Value, 64) + case "n": + p.n, _ = strconv.Atoi(sp.Value) + case "shloopitr": + p.shloopitr, _ = strconv.Atoi(sp.Value) + case "currentS": + p.currentS, _ = strconv.Atoi(sp.Value) + case "evaluatingTrials": + p.evaluatingTrials = strings.Split(sp.Value, ",") + default: + log.Printf("Unknown Suggestion Parameter %v", sp.Name) + } + } + if p.r_l <= 0 || p.ResourceName == "" { + log.Printf("Failed to purse Suggestion Parameter. 
r_l and ResourceName must be set.") + return nil, fmt.Errorf("Suggestion Parameter set Error") + } + if p.eta <= 0 { + p.eta = 3 + } + if p.ObjectiveValueName == "" { + gsreq := &api.GetStudyRequest{ + StudyId: studyId, + } + gsrep, err := c.GetStudy(ctx, gsreq) + if err != nil { + log.Printf("GetStudy Error") + return nil, err + } + p.ObjectiveValueName = gsrep.StudyConfig.ObjectiveValueName + } + if p.sMax == -1 { + p.sMax = int(math.Trunc(math.Log(p.r_l) / math.Log(p.eta))) + } + if p.b_l == -1 { + p.b_l = float64((p.sMax + 1.0)) * p.r_l + } + if p.n == -1 { + p.n = int(math.Ceil((p.b_l / p.r_l) * (math.Pow(p.eta, float64(p.sMax)) / float64(p.sMax+1)))) + } + if p.currentS == -1 { + p.currentS = p.sMax + } + if p.shloopitr == -1 { + p.shloopitr = 0 + } + if p.r == -1 { + p.r = p.r_l * math.Pow(p.eta, float64(-p.sMax)) + } + log.Printf("Hyb Param sMax %v", p.sMax) + log.Printf("Hyb Param B %v", p.b_l) + log.Printf("Hyb Param n %v", p.n) + log.Printf("Hyb Param currentS %v", p.currentS) + log.Printf("Hyb Param r %v", p.r) + log.Printf("Hyb Param evaluatingTrials %v", p.evaluatingTrials) + return p, nil +} + +func (h *HyperBandSuggestService) saveSuggestionParameters(ctx context.Context, c api.ManagerClient, studyId string, algorithm string, paramId string, hbparam *HyperBandParameters) error { + req := &api.SetSuggestionParametersRequest{ + StudyId: studyId, + SuggestionAlgorithm: algorithm, + ParamId: paramId, + } + sp := []*api.SuggestionParameter{} + sp = append(sp, &api.SuggestionParameter{ + Name: "eta", + Value: strconv.FormatFloat(hbparam.eta, 'f', 4, 64), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "sMax", + Value: strconv.Itoa(hbparam.sMax), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "b_l", + Value: strconv.FormatFloat(hbparam.b_l, 'f', 4, 64), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "r_l", + Value: strconv.FormatFloat(hbparam.r_l, 'f', 4, 64), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "r", + 
Value: strconv.FormatFloat(hbparam.r, 'f', 4, 64), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "shloopitr", + Value: strconv.Itoa(hbparam.shloopitr), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "n", + Value: strconv.Itoa(hbparam.n), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "currentS", + Value: strconv.Itoa(hbparam.currentS), + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "ResourceName", + Value: hbparam.ResourceName, + }) + sp = append(sp, &api.SuggestionParameter{ + Name: "evaluatingTrials", + Value: strings.Join(hbparam.evaluatingTrials, ","), + }) + req.SuggestionParameters = sp + _, err := c.SetSuggestionParameters(ctx, req) + return err +} + +func (h *HyperBandSuggestService) evalWorkers(ctx context.Context, c api.ManagerClient, studyId string, hbparam *HyperBandParameters) (error, Bracket) { + bracket := Bracket{} + for _, tid := range hbparam.evaluatingTrials { + gwreq := &api.GetWorkersRequest{ + TrialId: tid, + } + gwrep, err := c.GetWorkers(ctx, gwreq) + if err != nil { + return err, nil + } + wl := make([]string, len(gwrep.Workers)) + for i, w := range gwrep.Workers { + wl[i] = w.WorkerId + } + gmreq := &api.GetMetricsRequest{ + StudyId: studyId, + WorkerIds: wl, + MetricsNames: []string{hbparam.ObjectiveValueName}, + } + gmrep, err := c.GetMetrics(ctx, gmreq) + if err != nil { + log.Printf("GetMetrics error %v", err) + return err, nil + } + vs := 0.0 + for _, ml := range gmrep.MetricsLogSets { + if ml.WorkerStatus != api.State_COMPLETED { + return nil, nil + } + v, _ := strconv.ParseFloat(ml.MetricsLogs[0].Values[len(ml.MetricsLogs[0].Values)-1], 64) + vs += v + } + if len(gwrep.Workers) > 0 { + bracket = append(bracket, Evals{ + id: gwrep.Workers[0].TrialId, + value: vs / float64(len(gwrep.Workers)), + }) + } else { + return nil, nil + } + + } + sort.Sort(bracket) + return nil, bracket +} + +func (h *HyperBandSuggestService) hbLoopParamUpdate(studyId string, hbparam *HyperBandParameters) { + 
log.Printf("HB loop s = %v", hbparam.currentS) + hbparam.shloopitr = 0 + hbparam.n = int(math.Trunc((hbparam.b_l / hbparam.r_l) * (math.Pow(hbparam.eta, float64(hbparam.currentS)) / float64(hbparam.currentS+1)))) + hbparam.r = hbparam.r_l * math.Pow(hbparam.eta, float64(-hbparam.currentS)) +} + +func (h *HyperBandSuggestService) getLoopParam(studyId string, hbparam *HyperBandParameters) (int, float64) { + log.Printf("SH loop i = %v", hbparam.shloopitr) + n_i := int(math.Trunc(float64(hbparam.n) * math.Pow(hbparam.eta, float64(-hbparam.shloopitr)))) + r_i := hbparam.r * math.Pow(hbparam.eta, float64(hbparam.shloopitr)) + return n_i, r_i +} +func (h *HyperBandSuggestService) shLoopParamUpdate(studyId string, hbparam *HyperBandParameters) { + hbparam.shloopitr++ + if hbparam.shloopitr > hbparam.currentS { + hbparam.currentS-- + } +} func (h *HyperBandSuggestService) GetSuggestions(ctx context.Context, in *api.GetSuggestionsRequest) (*api.GetSuggestionsReply, error) { - return &api.GetSuggestionsReply{}, nil - // conn, err := grpc.Dial(manager, grpc.WithInsecure()) - // if err != nil { - // log.Fatalf("could not connect: %v", err) - // return - // } - // defer conn.Close() - // c := api.NewManagerClient(conn) - // screq := &api.GetStudyRequest{ - // StudyId: in.StudyId, - // } - // scr, err := GetStudy(ctx, screq) - // if err != nil { - // log.Fatalf("GetStudyConf failed: %v", err) - // return &api.GetSuggestionsReply{}, err - // } - // spreq := &api.GetSuggestionParametersRequest{ - // StudyId: in.StudyId, - // SuggestionAlgorithm: in.SuggestionAlgorithm, - // } - // spr, err := c.GetSuggestionParameters(ctx, spreq) - // if err != nil { - // log.Fatalf("GetParameter failed: %v", err) - // return &api.GetSuggestionsReply{}, err - // } - // hp, err := h.purseSuggestionParameters(spr.SuggestionParameters) - // if err != nil { - // return &api.GetSuggestionsReply{}, err - // } - // - // if hp.currentS <= 0 { - // return &api.GetSuggestionsReply{}, nil - // } - // - // if 
len(in.LogWorkerIds) > 0 { - // - // var schec int - // var bid string - // for _, c := range in.CompletedTrials { - // schec = 0 - // value, _ := h.dbIf.GetTrialLogs(c.TrialId, - // &db.GetTrialLogOpts{Objective: true, Descending: true, Limit: 1}) - // if len(value) != 1 { - // log.Printf("objective value for %s not found", - // c.TrialId) - // continue - // } - // c.ObjectiveValue = value[0].Value - // for _, t := range c.Tags { - // if t.Name == "HyperBand_shi" && t.Value == strconv.Itoa(h.parameters[in.StudyId].shloopitr) { - // schec++ - // } - // if t.Name == "HyperBand_s" && t.Value == strconv.Itoa(h.parameters[in.StudyId].currentS) { - // schec++ - // } - // if t.Name == "HyperBand_BracketID" { - // bid = t.Value - // } - // } - // if schec == 2 { - // for _, b := range h.parameters[in.StudyId].MasterBracket { - // for _, t := range b.Tags { - // if t.Name == "HyperBand_BracketID" && t.Value == bid { - // b.ObjectiveValue = c.ObjectiveValue - // } - // } - // } - // } - // } - // sort.Sort(h.parameters[in.StudyId].MasterBracket) - // } - // var evalT []*api.Trial - // var r_i int - // if h.parameters.shloopitr > h.parameters.currentS { - // h.parameters.currentS-- - // h.hbLoopParamUpdate(in.StudyId) - // _, r_i = h.shLoopParamUpdate(in.StudyId) - // h.parameters.MasterBracket = h.makeMasterBracket(in.Configs, h.parameters.n) - // evalT = h.getHyperParameter(in.StudyId, in.Configs, h.parameters.n) - // h.parameters.shloopitr++ - // } else { - // var pn_i int - // pn_i, r_i = h.shLoopParamUpdate(in.StudyId) - // evalT = h.getHyperParameter(in.StudyId, in.Configs, pn_i) - // h.parameters.shloopitr++ - // } - // for i := range evalT { - // for j := range evalT[i].ParameterSet { - // if evalT[i].ParameterSet[j].Name == h.parameters.ResourceName { - // evalT[i].ParameterSet[j].Value = strconv.Itoa(r_i) - // } - // } - // evalT[i].Tags = append(evalT[i].Tags, &api.Tag{Name: "HyperBand_s", Value: strconv.Itoa(h.parameters.currentS)}) - // evalT[i].Tags = 
append(evalT[i].Tags, &api.Tag{Name: "HyperBand_r", Value: strconv.Itoa(r_i)}) - // evalT[i].Tags = append(evalT[i].Tags, &api.Tag{Name: "HyperBand_shi", Value: strconv.Itoa(h.parameters.shloopitr)}) - // log.Printf("Gen Trial %v", evalT[i].Tags) - // } - // return &api.GenerateTrialsReply{Trials: evalT, Completed: false}, nil + conn, err := grpc.Dial(manager, grpc.WithInsecure()) + if err != nil { + log.Fatalf("could not connect: %v", err) + return &api.GetSuggestionsReply{}, err + } + defer conn.Close() + c := api.NewManagerClient(conn) + spreq := &api.GetSuggestionParametersRequest{ + ParamId: in.ParamId, + } + spr, err := c.GetSuggestionParameters(ctx, spreq) + if err != nil { + log.Fatalf("GetParameter failed: %v", err) + return &api.GetSuggestionsReply{}, err + } + hbparam, err := h.purseSuggestionParameters(ctx, c, in.StudyId, spr.SuggestionParameters) + if err != nil { + return &api.GetSuggestionsReply{}, err + } + + if hbparam.currentS <= 0 { + return &api.GetSuggestionsReply{}, nil + } + + if hbparam.shloopitr > hbparam.currentS { + h.hbLoopParamUpdate(in.StudyId, hbparam) + } + n_i, r_i := h.getLoopParam(in.StudyId, hbparam) + tids, ts, err := h.makeBracket(ctx, c, in.StudyId, n_i, r_i, hbparam) + if err != nil { + return &api.GetSuggestionsReply{}, err + } + if tids == nil { + return &api.GetSuggestionsReply{}, status.Errorf(codes.FailedPrecondition, "Previous workers are not completed.") + } + hbparam.evaluatingTrials = tids + h.shLoopParamUpdate(in.StudyId, hbparam) + err = h.saveSuggestionParameters(ctx, c, in.StudyId, in.SuggestionAlgorithm, in.ParamId, hbparam) + return &api.GetSuggestionsReply{ + Trials: ts, + }, nil } diff --git a/scripts/build.sh b/scripts/build.sh index c81b326e37e..5f6a8c22ef3 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -31,8 +31,8 @@ docker build -t ${PREFIX}/vizier-core -f ${CMD_PREFIX}/manager/Dockerfile . echo "Building suggestion images..." 
docker build -t ${PREFIX}/suggestion-random -f ${CMD_PREFIX}/suggestion/random/Dockerfile . docker build -t ${PREFIX}/suggestion-grid -f ${CMD_PREFIX}/suggestion/grid/Dockerfile . -#docker build -t ${PREFIX}/suggestion-hyperband -f ${CMD_PREFIX}/suggestion/hyperband/Dockerfile . -#docker build -t ${PREFIX}/suggestion-bayesianoptimization -f ${CMD_PREFIX}/suggestion/bayesianoptimization/Dockerfile . +docker build -t ${PREFIX}/suggestion-hyperband -f ${CMD_PREFIX}/suggestion/hyperband/Dockerfile . +docker build -t ${PREFIX}/suggestion-bayesianoptimization -f ${CMD_PREFIX}/suggestion/bayesianoptimization/Dockerfile . echo "Building earlystopping images..." docker build -t ${PREFIX}/earlystopping-medianstopping -f ${CMD_PREFIX}/earlystopping/medianstopping/Dockerfile . diff --git a/scripts/deploy.sh b/scripts/deploy.sh index ef4ab4cfa98..19c615e91db 100755 --- a/scripts/deploy.sh +++ b/scripts/deploy.sh @@ -29,4 +29,5 @@ kubectl apply -f manifests/vizier/db kubectl apply -f manifests/vizier/core kubectl apply -f manifests/vizier/suggestion/random kubectl apply -f manifests/vizier/suggestion/grid +kubectl apply -f manifests/vizier/suggestion/hyperband cd - > /dev/null diff --git a/test/e2e/study-config.yml b/test/e2e/study-config.yml index e2c29d0f622..e000436fb01 100644 --- a/test/e2e/study-config.yml +++ b/test/e2e/study-config.yml @@ -21,9 +21,15 @@ parameterconfigs: - sgd - adam - ftrl + - name: epoch-num + parametertype: 2 + feasible: + min: "10" + max: "10" defaultsuggestionalgorithm: random defaultearlystoppingalgorithm: medianstopping objectivevaluename: learning-rate metrics: - num-layers - optimizer +- epoch-num diff --git a/test/e2e/suggestion-config.yml b/test/e2e/suggestion-config-grid.yml similarity index 100% rename from test/e2e/suggestion-config.yml rename to test/e2e/suggestion-config-grid.yml diff --git a/test/e2e/suggestion-config-hyb.yml b/test/e2e/suggestion-config-hyb.yml new file mode 100644 index 00000000000..6cc6171567d --- /dev/null +++ 
b/test/e2e/suggestion-config-hyb.yml @@ -0,0 +1,12 @@ +suggetionalgorithm: "hyperband" +suggestionparameters: + - + name: "ResourceName" + value: "epoch-num" + - + name: "eta" + value: "3" + - + name: "r_l" + value: "3" + diff --git a/test/e2e/test-client.go b/test/e2e/test-client.go index 5e8102d708b..1f200947796 100644 --- a/test/e2e/test-client.go +++ b/test/e2e/test-client.go @@ -15,6 +15,7 @@ import ( var managerAddr = flag.String("s", "127.0.0.1:6789", "Endpoint of manager default 127.0.0.1:6789") var suggestArgo = flag.String("a", "random", "Suggestion Algorithm (random, grid)") var requestnum = flag.Int("r", 2, "Request number for random Suggestions (default: 2)") +var suggestionConfFile = flag.String("c", "", "File path to suggestion config.") var studyConfig = api.StudyConfig{} var workerConfig = api.WorkerConfig{} @@ -37,35 +38,73 @@ func main() { //CreateStudy studyId := CreateStudy(c) - //GetSuggestion - getSuggestReply := getSuggestion(c, studyId) + //SetSuggestParam + paramId := setSuggestionParam(c, studyId) - checkSuggestions(getSuggestReply) + //GetSuggestion + if *suggestArgo == "hyperband" { + for true { + getSuggestReply := getSuggestion(c, studyId, paramId) + checkSuggestions(getSuggestReply) + if len(getSuggestReply.Trials) == 0 { + log.Printf("Hyperband ended") + break + } + //RunTrials + workerIds := runTrials(c, studyId, getSuggestReply) - //RunTrials - workerIds := runTrials(c, studyId, getSuggestReply) + iter := 0 - iter := 0 + for !isCompletedAllWorker(c, studyId) { + if iter > TimeOut { + log.Fatal("GetMetrics Timeout.") + } + time.Sleep(1 * time.Second) + getMetricsRequest := &api.GetMetricsRequest{ + StudyId: studyId, + WorkerIds: workerIds, + } + //GetMetrics + getMetricsReply, err := c.GetMetrics(ctx, getMetricsRequest) + if err != nil { + continue + } + //Save or Update model on ModelDB + SaveOrUpdateModel(c, getMetricsReply) + iter++ + } + checkWorkersResult(c, studyId) - for !isCompletedAllWorker(c, studyId) { - if iter > 
TimeOut { - log.Fatal("GetMetrics Timeout.") - } - time.Sleep(1 * time.Second) - getMetricsRequest := &api.GetMetricsRequest{ - StudyId: studyId, - WorkerIds: workerIds, } - //GetMetrics - getMetricsReply, err := c.GetMetrics(ctx, getMetricsRequest) - if err != nil { - continue + } else { + getSuggestReply := getSuggestion(c, studyId, paramId) + checkSuggestions(getSuggestReply) + + //RunTrials + workerIds := runTrials(c, studyId, getSuggestReply) + + iter := 0 + + for !isCompletedAllWorker(c, studyId) { + if iter > TimeOut { + log.Fatal("GetMetrics Timeout.") + } + time.Sleep(1 * time.Second) + getMetricsRequest := &api.GetMetricsRequest{ + StudyId: studyId, + WorkerIds: workerIds, + } + //GetMetrics + getMetricsReply, err := c.GetMetrics(ctx, getMetricsRequest) + if err != nil { + continue + } + //Save or Update model on ModelDB + SaveOrUpdateModel(c, getMetricsReply) + iter++ } - //Save or Update model on ModelDB - SaveOrUpdateModel(c, getMetricsReply) - iter++ + checkWorkersResult(c, studyId) } - checkWorkersResult(c, studyId) conn.Close() log.Println("E2E test OK!") } @@ -90,9 +129,11 @@ func readConfigs() { log.Fatalf("fail to Unmarshal yaml") } - buf, err = ioutil.ReadFile("suggestion-config.yml") - if err != nil { - log.Fatalf("fail to read suggestion-config yaml") + if *suggestionConfFile != "" { + buf, err = ioutil.ReadFile(*suggestionConfFile) + if err != nil { + log.Fatalf("fail to read suggestion-config yaml") + } } err = yaml.Unmarshal(buf, &suggestionConfig) if err != nil { @@ -122,7 +163,33 @@ func CreateStudy(c api.ManagerClient) string { return studyId } -func getSuggestion(c api.ManagerClient, studyId string) *api.GetSuggestionsReply { +func setSuggestionParam(c api.ManagerClient, studyId string) string { + ctx := context.Background() + switch *suggestArgo { + case "random": + return "" + case "grid": + suggestionConfig.StudyId = studyId + setSuggesitonParameterReply, err := c.SetSuggestionParameters(ctx, &suggestionConfig) + if err != nil { + 
log.Fatalf("SetConfig Error %v", err) + } + log.Printf("Grid suggestion prameter ID %s", setSuggesitonParameterReply.ParamId) + return setSuggesitonParameterReply.ParamId + case "hyperband": + suggestionConfig.StudyId = studyId + setSuggesitonParameterReply, err := c.SetSuggestionParameters(ctx, &suggestionConfig) + if err != nil { + log.Fatalf("SetConfig Error %v", err) + } + log.Printf("HyperBand suggestion prameter ID %s", setSuggesitonParameterReply.ParamId) + return setSuggesitonParameterReply.ParamId + } + return "" + +} + +func getSuggestion(c api.ManagerClient, studyId string, paramId string) *api.GetSuggestionsReply { ctx := context.Background() var getSuggestRequest *api.GetSuggestionsRequest switch *suggestArgo { @@ -135,18 +202,19 @@ func getSuggestion(c api.ManagerClient, studyId string) *api.GetSuggestionsReply } case "grid": - suggestionConfig.StudyId = studyId - setSuggesitonParameterReply, err := c.SetSuggestionParameters(ctx, &suggestionConfig) - if err != nil { - log.Fatalf("SetConfig Error %v", err) - } - log.Printf("Grid suggestion prameter ID %s", setSuggesitonParameterReply.ParamId) getSuggestRequest = &api.GetSuggestionsRequest{ StudyId: studyId, SuggestionAlgorithm: "grid", RequestNumber: 0, //RequestNumber=0 means get all grids. - ParamId: setSuggesitonParameterReply.ParamId, + ParamId: paramId, + } + case "hyperband": + getSuggestRequest = &api.GetSuggestionsRequest{ + StudyId: studyId, + SuggestionAlgorithm: "hyperband", + RequestNumber: 0, + ParamId: paramId, } } diff --git a/test/scripts/build-suggestion-hyperband.sh b/test/scripts/build-suggestion-hyperband.sh new file mode 100755 index 00000000000..15541e14039 --- /dev/null +++ b/test/scripts/build-suggestion-hyperband.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# Copyright 2018 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This shell script is used to build an image from our argo workflow + +set -o errexit +set -o nounset +set -o pipefail + +export PATH=${GOPATH}/bin:/usr/local/go/bin:${PATH} +REGISTRY="${GCP_REGISTRY}" +PROJECT="${GCP_PROJECT}" +GO_DIR=${GOPATH}/src/github.com/${REPO_OWNER}/${REPO_NAME}-suggestion-hyperband +VERSION=$(git describe --tags --always --dirty) + +echo "Activating service-account" +gcloud auth activate-service-account --key-file=${GOOGLE_APPLICATION_CREDENTIALS} + +echo "Copy source to GOPATH" +mkdir -p ${GO_DIR} +cp -r cmd ${GO_DIR}/cmd +cp -r pkg ${GO_DIR}/pkg +cp -r vendor ${GO_DIR}/vendor + +cd ${GO_DIR} + +cp cmd/suggestion/hyperband/Dockerfile . +gcloud container builds submit . 
--tag=${REGISTRY}/${REPO_NAME}/suggestion-hyperband:${VERSION} --project=${PROJECT} diff --git a/test/scripts/run-tests.sh b/test/scripts/run-tests.sh index f7e6433629f..72c52c98918 100755 --- a/test/scripts/run-tests.sh +++ b/test/scripts/run-tests.sh @@ -46,7 +46,7 @@ sed -i -e "s@image: katib\/vizier-core@image: ${REGISTRY}\/${REPO_NAME}\/vizier- sed -i -e "s@type: NodePort@type: ClusterIP@" -e "/nodePort: 30678/d" manifests/vizier/core/service.yaml sed -i -e "s@image: katib\/suggestion-random@image: ${REGISTRY}\/${REPO_NAME}\/suggestion-random:${VERSION}@" manifests/vizier/suggestion/random/deployment.yaml sed -i -e "s@image: katib\/suggestion-grid@image: ${REGISTRY}\/${REPO_NAME}\/suggestion-grid:${VERSION}@" manifests/vizier/suggestion/grid/deployment.yaml -#sed -i -e "s@image: katib\/suggestion-hyperband@image: ${REGISTRY}\/${REPO_NAME}\/suggestion-hyperband:${VERSION}@" manifests/vizier/suggestion/hyperband/deployment.yaml +sed -i -e "s@image: katib\/suggestion-hyperband@image: ${REGISTRY}\/${REPO_NAME}\/suggestion-hyperband:${VERSION}@" manifests/vizier/suggestion/hyperband/deployment.yaml #sed -i -e "s@image: katib\/suggestion-bayesianoptimization@image: ${REGISTRY}\/${REPO_NAME}\/suggestion-bayesianoptimization:${VERSION}@" manifests/vizier/suggestion/bayesianoptimization/deployment.yaml sed -i -e "s@image: katib\/earlystopping-medianstopping@image: ${REGISTRY}\/${REPO_NAME}\/earlystopping-medianstopping:${VERSION}@" manifests/vizier/earlystopping/medianstopping/deployment.yaml sed -i -e "s@image: katib\/katib-frontend@image: ${REGISTRY}\/${REPO_NAME}\/katib-frontend:${VERSION}@" manifests/modeldb/frontend/deployment.yaml @@ -84,4 +84,5 @@ done cp -r test ${GO_DIR}/test cd ${GO_DIR}/test/e2e go run test-client.go -a random -go run test-client.go -a grid +go run test-client.go -a grid -c suggestion-config-grid.yml +go run test-client.go -a hyperband -c suggestion-config-hyb.yml diff --git a/test/workflows/components/workflows.libsonnet 
b/test/workflows/components/workflows.libsonnet index bb94c8e03bc..44574bea4b4 100644 --- a/test/workflows/components/workflows.libsonnet +++ b/test/workflows/components/workflows.libsonnet @@ -218,6 +218,10 @@ name: "build-suggestion-grid", template: "build-suggestion-grid", }, + { + name: "build-suggestion-hyperband", + template: "build-suggestion-hyperband", + }, { name: "build-suggestion-bo", template: "build-suggestion-bo", @@ -327,6 +331,9 @@ $.parts(namespace, name).e2e(prow_env, bucket).buildTemplate("build-suggestion-grid", testWorkerImage, [ "test/scripts/build-suggestion-grid.sh", ]), // build-suggestion-grid + $.parts(namespace, name).e2e(prow_env, bucket).buildTemplate("build-suggestion-hyperband", testWorkerImage, [ + "test/scripts/build-suggestion-hyperband.sh", + ]), // build-suggestion-hyperband $.parts(namespace, name).e2e(prow_env, bucket).buildTemplate("build-suggestion-bo", testWorkerImage, [ "test/scripts/build-suggestion-bo.sh", ]), // build-suggestion-bo