diff --git a/object_manager.go b/object_manager.go index 65483f06..6c57c17a 100644 --- a/object_manager.go +++ b/object_manager.go @@ -98,110 +98,143 @@ type IBObjectManager interface { UpdateDhcpStatus(ref string, status bool) (Dhcp, error) } +const ( + ARecord = "A" + AaaaRecord = "AAAA" + CnameRecord = "CNAME" + MxRecord = "MX" + SrvRecord = "SRV" + TxtRecord = "TXT" + PtrRecord = "PTR" + HostRecordConst = "Host" +) + // Map of record type to its corresponding object var getRecordTypeMap = map[string]func() IBObject{ - "A": func() IBObject { + ARecord: func() IBObject { return NewEmptyRecordA() }, - "AAAA": func() IBObject { + AaaaRecord: func() IBObject { return NewEmptyRecordAAAA() }, - "CNAME": func() IBObject { + CnameRecord: func() IBObject { return NewEmptyRecordCNAME() }, - "MX": func() IBObject { + MxRecord: func() IBObject { return NewEmptyRecordMX() }, + SrvRecord: func() IBObject { + return NewEmptyRecordSRV() + }, + TxtRecord: func() IBObject { + return NewEmptyRecordTXT() + }, + PtrRecord: func() IBObject { + return NewEmptyRecordPTR() + }, + HostRecordConst: func() IBObject { + return NewEmptyHostRecord() + }, } // Map returns the object with search fields with the given record type var getObjectWithSearchFieldsMap = map[string]func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error){ - "A": func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + ARecord: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { var res interface{} if recordType.(*RecordA).Ref != "" { return res, nil } - - err := objMgr.connector.GetObject(NewEmptyRecordA(), "", NewQueryParams(false, sf), &res) - var newVal []RecordA - byteVal, err := json.Marshal(res) - if err != nil { - return nil, err + var recordAList []*RecordA + err := objMgr.connector.GetObject(NewEmptyRecordA(), "", NewQueryParams(false, sf), &recordAList) + if err == nil && len(recordAList) > 0 { + res = recordAList[0] } - err = json.Unmarshal(byteVal, &newVal) - if err != nil { - return nil, err + return res, err + }, + AaaaRecord: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + var res interface{} + if recordType.(*RecordAAAA).Ref != "" { + return res, nil } - if newVal == nil || len(newVal) == 0 { - return nil, NewNotFoundError("record not found") + var recordAaaList []*RecordAAAA + err := objMgr.connector.GetObject(NewEmptyRecordAAAA(), "", NewQueryParams(false, sf), &recordAaaList) + if err == nil && len(recordAaaList) > 0 { + res = recordAaaList[0] } - res = newVal return res, err }, - "AAAA": func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + CnameRecord: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { var res interface{} - if recordType.(*RecordAAAA).Ref != "" { + if recordType.(*RecordCNAME).Ref != "" { return res, nil } - - err := objMgr.connector.GetObject(NewEmptyRecordAAAA(), "", NewQueryParams(false, sf), &res) - var newVal []RecordAAAA - byteVal, err := json.Marshal(res) - if err != nil { - return nil, err + var cNameList []*RecordCNAME + err := objMgr.connector.GetObject(NewEmptyRecordCNAME(), "", NewQueryParams(false, sf), &cNameList) + if err == nil && len(cNameList) > 0 { + res = cNameList[0] } - err = json.Unmarshal(byteVal, &newVal) - if err != nil { - return nil, err + return res, err + }, + MxRecord: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + var res interface{} + if recordType.(*RecordMX).Ref != "" { + return res, nil } - if newVal == nil || len(newVal) == 0 { - return nil, NewNotFoundError("record not found") + var mxList []*RecordMX + err := objMgr.connector.GetObject(NewEmptyRecordMX(), "", NewQueryParams(false, sf), &mxList) + if err == nil && len(mxList) > 0 { + res = mxList[0] } - res = newVal return res, err + }, - "CNAME": func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + SrvRecord: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { var res interface{} - if recordType.(*RecordCNAME).Ref != "" { + if recordType.(*RecordSRV).Ref != "" { return res, nil } - err := objMgr.connector.GetObject(NewEmptyRecordCNAME(), "", NewQueryParams(false, sf), &res) - var newVal []RecordCNAME - byteVal, err := json.Marshal(res) - if err != nil { - return nil, err + var srvList []*RecordSRV + err := objMgr.connector.GetObject(NewEmptyRecordSRV(), "", NewQueryParams(false, sf), &srvList) + if err == nil && len(srvList) > 0 { + res = srvList[0] } - err = json.Unmarshal(byteVal, &newVal) - if err != nil { - return nil, err + return res, err + }, + TxtRecord: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + var res interface{} + if recordType.(*RecordTXT).Ref != "" { + return res, nil } - if newVal == nil || len(newVal) == 0 { - return nil, NewNotFoundError("record not found") + var txtList []*RecordTXT + err := objMgr.connector.GetObject(NewEmptyRecordTXT(), "", NewQueryParams(false, sf), &txtList) + if err == nil && len(txtList) > 0 { + res = txtList[0] } - res = newVal return res, err }, - "MX": func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + PtrRecord: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { var res interface{} - if recordType.(*RecordMX).Ref != "" { + if recordType.(*RecordPTR).Ref != "" { return res, nil } - err := objMgr.connector.GetObject(NewEmptyRecordMX(), "", NewQueryParams(false, sf), &res) - var newVal []RecordMX - byteVal, err := json.Marshal(res) - if err != nil { - return nil, err + var ptrList []*RecordPTR + err := objMgr.connector.GetObject(NewEmptyRecordPTR(), "", NewQueryParams(false, sf), &ptrList) + if err == nil && len(ptrList) > 0 { + res = ptrList[0] } - err = json.Unmarshal(byteVal, &newVal) - if err != nil { - return nil, err + return res, err + }, + HostRecordConst: func(recordType IBObject, objMgr *ObjectManager, sf map[string]string) (interface{}, error) { + var res interface{} + if recordType.(*HostRecord).Ref != "" { + return res, nil } - if newVal == nil || len(newVal) == 0 { - return nil, NewNotFoundError("record not found") + var hostRecordList []*HostRecord + err := objMgr.connector.GetObject(NewEmptyHostRecord(), "", NewQueryParams(false, sf), &hostRecordList) + if err == nil && len(hostRecordList) > 0 { + res = hostRecordList[0] } - res = newVal return res, err - }, } @@ -430,12 +463,11 @@ func (objMgr *ObjectManager) SearchDnsObjectByAltId( recordType IBObject res interface{} ) - - if getRecordTypeMap[objType] != nil { - recordType = getRecordTypeMap[objType]() - } else { + val, ok := getRecordTypeMap[objType] + if !ok { return nil, fmt.Errorf("unknown record type") } + recordType = val() if ref != "" { if err := objMgr.connector.GetObject(recordType, ref, NewQueryParams(false, nil), &res); err != nil { @@ -457,20 +489,17 @@ func (objMgr *ObjectManager) SearchDnsObjectByAltId( } // Fetch the object by search fields - if getObjectWithSearchFieldsMap[objType] != nil { - res, err = getObjectWithSearchFieldsMap[objType](recordType, objMgr, sf) - } else { + getObjectWithSearchFields, ok := getObjectWithSearchFieldsMap[objType] + if !ok { return nil, fmt.Errorf("unknown record type") } - + res, err = getObjectWithSearchFields(recordType, objMgr, sf) if err != nil { return nil, err } - if res == nil { return nil, NewNotFoundError("record not found") } - result := res - return &result, nil + return &res, nil }