diff --git a/pkg/sr/api.go b/pkg/sr/api.go new file mode 100644 index 00000000..7414b5e1 --- /dev/null +++ b/pkg/sr/api.go @@ -0,0 +1,616 @@ +package sr + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "sort" + "sync" + "sync/atomic" +) + +// This file is an implementation of: +// +// https://docs.confluent.io/platform/current/schema-registry/develop/api.html +// + +// SupportedTypes returns the schema types that are supported in the schema +// registry. +func (cl *Client) SupportedTypes(ctx context.Context) ([]SchemaType, error) { + // GET /schemas/types + var types []SchemaType + return types, cl.get(ctx, "/schemas/types", &types) +} + +// SchemaReference is a way for a one schema to reference another. The details +// for how referencing is done are type specific; for example, JSON objects +// that use the key "$ref" can refer to another schema via URL. For more details +// on references, see the following link: +// +// https://docs.confluent.io/platform/current/schema-registry/serdes-develop/index.html#schema-references +// https://docs.confluent.io/platform/current/schema-registry/develop/api.html +// +type SchemaReference struct { + Name string `json:"name"` + Subject string `json:"subject"` + Version int `json:"version"` +} + +// Schema is the object form of a schema for the HTTP API. +type Schema struct { + // Schema is the actual unescaped text of a schema. + Schema string `json:"schema"` + + // Type is the type of a schema. The default type is avro. + Type SchemaType `json:"schemaType,omitempty"` + + // References declares other schemas this schema references. See the + // docs on SchemaReference for more details. 
+ References []SchemaReference `json:"references,omitempty"` +} + +type rawSchema struct { + Schema string `json:"schema"` + Type SchemaType `json:"schemaType,omitempty"` + References []SchemaReference `json:"references,omitempty"` +} + +func (s *Schema) UnmarshalJSON(b []byte) error { + var raw rawSchema + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + + *s = Schema(raw) + if err := json.Unmarshal([]byte(raw.Schema), &s.Schema); err != nil { + return err + } + return nil +} + +func (s Schema) MarshalJSON() ([]byte, error) { + b, err := json.Marshal(s.Schema) + if err != nil { + return nil, err + } + raw := rawSchema(s) + raw.Schema = string(b) + return json.Marshal(raw) +} + +// SubjectSchema pairs the subject, global identifier, and version of a schema +// with the schema itself. +type SubjectSchema struct { + // Subject is the subject for this schema. This usually corresponds to + // a Kafka topic, and whether this is for a key or value. For example, + // "foo-key" would be the subject for the foo topic for serializing the + // key field of a record. + Subject string `json:"subject"` + + // Version is the version of this subject. + Version int `json:"version"` + + // ID is the globally unique ID of the schema. + ID int `json:"id"` + + Schema +} + +// Subjects returns all alive and soft-deleted subjects available in the +// registry. 
+func (cl *Client) Subjects(ctx context.Context) (alive, softDeleted []string, err error) { + // GET /subjects?deleted={x} + if err = cl.get(ctx, "/subjects", &alive); err != nil { + return nil, nil, err + } + var all []string + if err = cl.get(ctx, "/subjects?deleted=true", &all); err != nil { + return nil, nil, err + } + mdeleted := make(map[string]struct{}, len(all)) + for _, subject := range all { + mdeleted[subject] = struct{}{} + } + for _, subject := range alive { + delete(mdeleted, subject) + } + for subject := range mdeleted { + softDeleted = append(softDeleted, subject) + } + sort.Strings(alive) + sort.Strings(softDeleted) + return alive, softDeleted, nil +} + +// SchemaTextByID returns the actual text of a schema. +// +// For example, if the schema for an ID is +// +// "{\"type\":\"boolean\"}" +// +// this will return +// +// {"type":"boolean"} +// +func (cl *Client) SchemaTextByID(ctx context.Context, id int) (string, error) { + // GET /schemas/ids/{id} + var s Schema + if err := cl.get(ctx, fmt.Sprintf("/schemas/ids/%d", id), &s); err != nil { + return "", err + } + return s.Schema, nil +} + +func pathSubject(subject string) string { return fmt.Sprintf("/subjects/%s", url.PathEscape(subject)) } +func pathSubjectWithVersion(subject string) string { return pathSubject(subject) + "/versions" } +func pathSubjectVersion(subject string, version int) string { + if version == -1 { + return pathSubjectWithVersion(subject) + "/latest" + } + return fmt.Sprintf("%s/%d", pathSubjectWithVersion(subject), version) +} + +func pathConfig(subject string) string { + if subject == "" { + return "/config" + } + return fmt.Sprintf("/config/%s", url.PathEscape(subject)) +} + +func pathMode(subject string) string { + if subject == "" { + return "/mode" + } + return fmt.Sprintf("/mode/%s", url.PathEscape(subject)) +} + +// SchemaByVersion returns the schema for a given subject and version. You can +// use -1 as the version to return the latest schema. 
+func (cl *Client) SchemaByVersion(ctx context.Context, subject string, version int) (SubjectSchema, error) { + // GET /subjects/{subject}/versions/{version} + var ss SubjectSchema + return ss, cl.get(ctx, pathSubjectVersion(subject, version), &ss) +} + +// Schemas returns all schemas for the given subject. +func (cl *Client) Schemas(ctx context.Context, subject string) ([]SubjectSchema, error) { + // GET /subjects/{subject}/versions => []int (versions) + var versions []int + if err := cl.get(ctx, pathSubjectWithVersion(subject), &versions); err != nil { + return nil, err + } + sort.Ints(versions) + + var ( + schemas = make([]SubjectSchema, len(versions)) + firstErr error + errOnce uint32 + wg sync.WaitGroup + cctx, cancel = context.WithCancel(ctx) + ) + for i := range versions { + version := versions[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + s, err := cl.SchemaByVersion(cctx, subject, version) + schemas[slot] = s + if err != nil && atomic.SwapUint32(&errOnce, 1) == 0 { + firstErr = err + cancel() + } + }() + } + wg.Wait() + + return schemas, firstErr +} + +// CreateSchema attempts to create a schema in the given subject. +func (cl *Client) CreateSchema(ctx context.Context, subject string, s Schema) (SubjectSchema, error) { + // POST /subjects/{subject}/versions => returns ID + path := pathSubjectWithVersion(subject) + if cl.normalize { + path += "?normalize=true" + } + var id int + if err := cl.post(ctx, path, s, &id); err != nil { + return SubjectSchema{}, err + } + + usages, err := cl.SchemaUsagesByID(ctx, id) + if err != nil { + return SubjectSchema{}, err + } + for _, usage := range usages { + if usage.Subject == subject { + return usage, nil + } + } + return SubjectSchema{}, fmt.Errorf("created schema under id %d, but unable to find SubjectSchema") +} + +// LookupSchema checks to see if a schema is already registered and if so, +// returns its ID and version in the SubjectSchema. 
+func (cl *Client) LookupSchema(ctx context.Context, subject string, s Schema) (SubjectSchema, error) { + // POST /subjects/{subject}/ + path := pathSubject(subject) + if cl.normalize { + path += "?normalize=true" + } + var ss SubjectSchema + return ss, cl.post(ctx, path, s, &ss) +} + +// DeleteHow is a typed bool indicating how subjects or schemas should be +// deleted. +type DeleteHow bool + +const ( + // SoftDelete performs a soft deletion. + SoftDelete = false + // HardDelete performs a hard deletion. Values must be soft deleted + // before they can be hard deleted. + HardDelete = true +) + +// DeleteSubjects deletes the subject. You must soft delete a subject before it +// can be hard deleted. This returns all versions that were deleted. +func (cl *Client) DeleteSubject(ctx context.Context, how DeleteHow, subject string) ([]int, error) { + // DELETE /subjects/{subject}?permanent={x} + path := pathSubject(subject) + if how == HardDelete { + path += "?permanent=true" + } + var versions []int + defer func() { sort.Ints(versions) }() + return versions, cl.delete(ctx, path, &versions) +} + +// DeleteSubjects deletes the schema at the given version. You must soft delete +// a schema before it can be hard deleted. You can use -1 to delete the latest +// version. +func (cl *Client) DeleteSchema(ctx context.Context, how DeleteHow, subject string, version int) error { + // DELETE /subjects/{subject}/versions/{version}?permanent={x} + path := pathSubjectVersion(subject, version) + if how == HardDelete { + path += "?permanent=true" + } + return cl.delete(ctx, path, nil) +} + +// SchemaReferences returns all schemas that references the input +// subject-version. You can use -1 to check the latest version. 
+func (cl *Client) SchemaReferences(ctx context.Context, subject string, version int) ([]SubjectSchema, error) { + // GET /subjects/{subject}/versions/{version}/referencedby + // SchemaUsagesByID + var ids []int + if err := cl.get(ctx, pathSubjectVersion(subject, version)+"/referencedby", &ids); err != nil { + return nil, err + } + + var ( + schemas []SubjectSchema + firstErr error + mu sync.Mutex + wg sync.WaitGroup + cctx, cancel = context.WithCancel(ctx) + ) + for i := range ids { + id := ids[i] + wg.Add(1) + go func() { + defer wg.Done() + idSchemas, err := cl.SchemaUsagesByID(cctx, id) + mu.Lock() + defer mu.Unlock() + schemas = append(schemas, idSchemas...) + if err != nil && firstErr == nil { + firstErr = err + cancel() + } + }() + } + wg.Wait() + + return schemas, firstErr +} + +// SchemaUsagesByID returns all usages of a given schema ID. A single schema's +// can be reused in many subject-versions; this function can be used to map a +// schema to all subject-versions that use it. 
+func (cl *Client) SchemaUsagesByID(ctx context.Context, id int) ([]SubjectSchema, error) { + // GET /schemas/ids/{id}/versions + // SchemaByVersion + type subjectVersion struct { + Subject string `json:"subject"` + Version int `json:"version"` + } + var subjectVersions []subjectVersion + if err := cl.get(ctx, fmt.Sprintf("/schemas/ids/%d/versions", id), &subjectVersions); err != nil { + return nil, err + } + + var ( + schemas = make([]SubjectSchema, len(subjectVersions)) + firstErr error + errOnce uint32 + wg sync.WaitGroup + cctx, cancel = context.WithCancel(ctx) + ) + for i := range subjectVersions { + sv := subjectVersions[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + s, err := cl.SchemaByVersion(cctx, sv.Subject, sv.Version) + schemas[slot] = s + if err != nil && atomic.SwapUint32(&errOnce, 1) == 0 { + firstErr = err + cancel() + } + }() + } + wg.Wait() + + if firstErr != nil { + return nil, firstErr + } + + type ssi struct { + subject string + version int + id int + } + + uniq := make(map[ssi]SubjectSchema) + for _, s := range schemas { + uniq[ssi{ + subject: s.Subject, + version: s.Version, + id: s.ID, + }] = s + } + schemas = nil + for _, s := range uniq { + schemas = append(schemas, s) + } + return schemas, nil +} + +// GlobalSubject is a constant to make API usage of requesting global subjects +// clearer. +const GlobalSubject = "" + +// CompatibilityResult is the compatibility level for a subject. +type CompatibilityResult struct { + Subject string // The subject this compatbility result is for, or empty for the global level. + Level CompatibilityLevel // The subject (or global) compatibilty level. + Err error // The error received for getting this compatibility level. +} + +// CompatibilityLevel returns the subject level and global level compatibility +// of each requested subject. The global level can be requested by using either +// an empty subject or by specifying no subjects. 
+func (cl *Client) CompatibilityLevel(ctx context.Context, subjects ...string) []CompatibilityResult { + // GET /config/{subject} + // GET /config + if len(subjects) == 0 { + subjects = append(subjects, GlobalSubject) + } + var ( + wg sync.WaitGroup + results = make([]CompatibilityResult, len(subjects)) + ) + for i := range subjects { + subject := subjects[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + var c struct { + Level CompatibilityLevel `json:"compatibilityLevel"` + } + err := cl.get(ctx, pathConfig(subject), &c) + results[slot] = CompatibilityResult{ + Subject: subject, + Level: c.Level, + Err: err, + } + }() + } + wg.Wait() + + return results +} + +// SetCompatibilityLevel sets the compatibility level for each requested +// subject. The global level can be set by either using an empty subject or by +// specifying no subjects. If specifying no subjects, this returns one element. +func (cl *Client) SetCompatibilityLevel(ctx context.Context, level CompatibilityLevel, subjects ...string) []CompatibilityResult { + // PUT /config/{subject} + // PUT /config + if len(subjects) == 0 { + subjects = append(subjects, GlobalSubject) + } + var ( + wg sync.WaitGroup + results = make([]CompatibilityResult, len(subjects)) + ) + for i := range subjects { + subject := subjects[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + c := struct { + Level CompatibilityLevel `json:"compatibility"` + }{level} + err := cl.put(ctx, pathConfig(subject), c, &c) + results[slot] = CompatibilityResult{ + Subject: subject, + Level: c.Level, + Err: err, + } + }() + } + wg.Wait() + + return results +} + +// ResetCompatibilityLevel deletes any subject-level compatibility level and +// reverts to the global default. 
+func (cl *Client) ResetCompatibilityLevel(ctx context.Context, subjects ...string) []CompatibilityResult { + // DELETE /config/{subject} + if len(subjects) == 0 { + return nil + } + var ( + wg sync.WaitGroup + results = make([]CompatibilityResult, len(subjects)) + ) + for i := range subjects { + subject := subjects[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + var c struct { + Level CompatibilityLevel `json:"compatibility"` + } + err := cl.delete(ctx, pathConfig(subject), &c) + results[slot] = CompatibilityResult{ + Subject: subject, + Level: c.Level, + Err: err, + } + }() + } + wg.Wait() + + return results +} + +// CheckCompatibility checks if a schema is compatible with the given version +// that exists. You can use -1 to check compatibility with all versions. +func (cl *Client) CheckCompatibility(ctx context.Context, subject string, version int, s Schema) (bool, error) { + // POST /compatibility/subjects/{subject}/versions/{version}?reason=true + // POST /compatibility/subjects/{subject}/versions?reason=true + path := pathSubjectVersion(subject, version) + if version == -1 { + path = pathSubjectWithVersion(subject) + } + var is bool + return is, cl.post(ctx, path, s, &is) +} + +// ModeResult is the mode for a subject. +type ModeResult struct { + Subject string // The subject this mode result is for, or empty for the global mode. + Mode Mode // The subject (or global) mode. + Err error // The error received for getting this mode. +} + +type modeResponse struct { + Mode Mode `json:"mode"` +} + +// Mode returns the subject and global mode of each requested subject. The +// global mode can be requested by using either an empty subject or by +// specifying no subjects. 
+func (cl *Client) Mode(ctx context.Context, subjects ...string) []ModeResult { + // GET /mode/{subject} + // GET /mode + if len(subjects) == 0 { + subjects = append(subjects, GlobalSubject) + } + var ( + wg sync.WaitGroup + results = make([]ModeResult, len(subjects)) + ) + for i := range subjects { + subject := subjects[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + var m modeResponse + err := cl.get(ctx, pathMode(subject), &m) + results[slot] = ModeResult{ + Subject: subject, + Mode: m.Mode, + Err: err, + } + }() + } + wg.Wait() + + return results +} + +// SetMode sets the mode for each requested subject. The global mode can be set +// by either using an empty subject or by specifying no subjects. If specifying +// no subjects, this returns one element. +func (cl *Client) SetMode(ctx context.Context, mode Mode, subjects ...string) []ModeResult { + // PUT /mode/{subject} + // PUT /mode + if len(subjects) == 0 { + subjects = append(subjects, GlobalSubject) + } + var ( + wg sync.WaitGroup + results = make([]ModeResult, len(subjects)) + ) + for i := range subjects { + subject := subjects[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + var m modeResponse + err := cl.put(ctx, pathMode(subject), m, &m) + results[slot] = ModeResult{ + Subject: subject, + Mode: m.Mode, + Err: err, + } + }() + } + wg.Wait() + + return results +} + +// ResetMode deletes any subject modes and reverts to the global default. 
+func (cl *Client) ResetMode(ctx context.Context, subjects ...string) []ModeResult { + // DELETE /mode/{subject} + if len(subjects) == 0 { + return nil + } + var ( + wg sync.WaitGroup + results = make([]ModeResult, len(subjects)) + ) + for i := range subjects { + subject := subjects[i] + slot := i + wg.Add(1) + go func() { + defer wg.Done() + var m modeResponse + err := cl.delete(ctx, pathMode(subject), &m) + results[slot] = ModeResult{ + Subject: subject, + Mode: m.Mode, + Err: err, + } + }() + } + wg.Wait() + + return results +} diff --git a/pkg/sr/client.go b/pkg/sr/client.go new file mode 100644 index 00000000..8cc6f4bc --- /dev/null +++ b/pkg/sr/client.go @@ -0,0 +1,151 @@ +// Package sr provides a schema registry client and a helper type to encode +// values and decode data according to the schema registry wire format. +// +// As mentioned on the Serde type, this package does not provide schema +// auto-discovery and type auto-decoding. To aid in strong typing and validated +// encoding/decoding, you must register IDs and values to how to encode or +// decode them. +// +// The client does not automatically cache schemas, instead, the Serde type is +// used for the actual caching of IDs to how to encode/decode the IDs. The +// Client type itself simply speaks http to your schema registry and returns +// the results. +// +// To read more about the schema registry, see the following: +// +// https://docs.confluent.io/platform/current/schema-registry/develop/api.html +// +package sr + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync/atomic" + "time" +) + +// ResponseError is the type returned from the schema registry for errors. +type ResponseError struct { + // Method is the requested http method. + Method string `json:"-"` + // URL is the full path that was requested that resulted in this error. 
+ URL string `json:"-"` + + ErrorCode int `json:"error_code"` + Message string `json:"message"` +} + +func (e *ResponseError) Error() string { return e.Message } + +// Client talks to a schema registry and contains helper functions to serialize +// and deserialize objects according to schemas. +type Client struct { + urls []string + httpcl *http.Client + + basicAuth *struct { + user string + pass string + } + + normalize bool + + serdes atomic.Value // map[reflect.Type]serde +} + +// NewClient returns a new schema registry client. +func NewClient(opts ...Opt) (*Client, error) { + cl := &Client{ + urls: []string{"http://localhost:8081"}, + httpcl: &http.Client{Timeout: 5 * time.Second}, + } + + for _, opt := range opts { + opt.apply(cl) + } + + if len(cl.urls) == 0 { + return nil, errors.New("unable to create client with no URLs") + } + + return cl, nil +} + +func (cl *Client) get(ctx context.Context, path string, into interface{}) error { + return cl.do(ctx, http.MethodGet, path, nil, into) +} + +func (cl *Client) post(ctx context.Context, path string, v interface{}, into interface{}) error { + return cl.do(ctx, http.MethodPost, path, v, into) +} + +func (cl *Client) put(ctx context.Context, path string, v interface{}, into interface{}) error { + return cl.do(ctx, http.MethodPut, path, v, into) +} + +func (cl *Client) delete(ctx context.Context, path string, into interface{}) error { + return cl.do(ctx, http.MethodDelete, path, nil, into) +} + +func (cl *Client) do(ctx context.Context, method, path string, v interface{}, into interface{}) error { + urls := cl.urls + +start: + url := fmt.Sprintf("%s%s", urls[0], path) + urls = urls[1:] + + var reqBody io.Reader + if v != nil { + marshaled, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("unable to encode body for %s %q: %w", method, url, err) + } + reqBody = bytes.NewReader(marshaled) + } + + req, err := http.NewRequestWithContext(ctx, method, url, reqBody) + if err != nil { + return fmt.Errorf("unable 
to create request for %s %q: %v", method, url, err) + } + req.Header.Set("Content-Type", "application/vnd.schemaregistry.v1+json") + if cl.basicAuth != nil { + req.SetBasicAuth(cl.basicAuth.user, cl.basicAuth.pass) + } + + resp, err := cl.httpcl.Do(req) + if err != nil { + if len(urls) == 0 { + return fmt.Errorf("unable to %s %q: %w", method, url, err) + } + goto start + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return fmt.Errorf("unable to read response body from %s %q: %w", method, url, err) + } + + if resp.StatusCode != 200 { + e := &ResponseError{ + Method: method, + URL: url, + } + if err := json.Unmarshal(body, e); err != nil { + return fmt.Errorf("unable to decode erroring response body from %s %q: %w", method, url, err) + } + return e + } + + if into != nil { + if err := json.Unmarshal(body, into); err != nil { + return fmt.Errorf("unable to decode ok response body from %s %q: %w", method, url, err) + } + } + return nil +} diff --git a/pkg/sr/config.go b/pkg/sr/config.go new file mode 100644 index 00000000..7ff04f06 --- /dev/null +++ b/pkg/sr/config.go @@ -0,0 +1,78 @@ +package sr + +import ( + "crypto/tls" + "net" + "net/http" + "strings" + "time" +) + +type ( + // Opt is an option to configure a client. + Opt interface{ apply(*Client) } + opt struct{ fn func(*Client) } +) + +func (o opt) apply(cl *Client) { o.fn(cl) } + +// HTTPClient sets the http client that the schema registry client uses, +// overriding the default client that speaks plaintext with a timeout of 5s. +func HTTPClient(httpcl *http.Client) Opt { + return opt{func(cl *Client) { cl.httpcl = httpcl }} +} + +// URLs sets the URLs that the client speaks to, overriding the default +// http://localhost:8081. This option automatically prefixes any URL that is +// missing an http:// or https:// prefix with http://. 
+func URLs(urls ...string) Opt { + return opt{func(cl *Client) { + for i, u := range urls { + if !strings.HasPrefix(u, "http://") { + urls[i] = "http://" + u + } + } + cl.urls = urls + }} +} + +// DialTLSConfig sets a tls.Config to use in a the default http client. +func DialTLSConfig(c *tls.Config) Opt { + return opt{func(cl *Client) { + cl.httpcl = &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSClientConfig: c, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + }} +} + +// Normalize sets the client to add the "?normalize=true" query parameter when +// getting or creating schemas. This can help collapse duplicate schemas into +// one, but can also be done with a configuration parameter on the schema +// registry itself. +func Normalize() Opt { + return opt{func(cl *Client) { cl.normalize = true }} +} + +// BasicAuth sets basic authorization to use for every request. +func BasicAuth(user, pass string) Opt { + return opt{func(cl *Client) { + cl.basicAuth = &struct { + user string + pass string + }{user, pass} + }} +} diff --git a/pkg/sr/enums.go b/pkg/sr/enums.go new file mode 100644 index 00000000..66bbb03e --- /dev/null +++ b/pkg/sr/enums.go @@ -0,0 +1,156 @@ +package sr + +import "fmt" + +// SchemaType as an enum representing schema types. The default schema type +// is avro. 
type SchemaType int

const (
	TypeAvro SchemaType = iota
	TypeProtobuf
	TypeJSON
)

// String returns the registry's name for the schema type, or the empty
// string for an unknown value.
func (t SchemaType) String() string {
	switch t {
	case TypeAvro:
		return "AVRO"
	case TypeProtobuf:
		return "PROTOBUF"
	case TypeJSON:
		return "JSON"
	default:
		return ""
	}
}

// MarshalText implements encoding.TextMarshaler.
func (t SchemaType) MarshalText() ([]byte, error) {
	s := t.String()
	if s == "" {
		return nil, fmt.Errorf("unknown schema type %d", t)
	}
	return []byte(s), nil
}

// UnmarshalText implements encoding.TextUnmarshaler, which is what allows
// encoding/json to decode the registry's string form into a SchemaType. The
// empty string decodes as the default type, avro.
func (t *SchemaType) UnmarshalText(text []byte) error {
	switch s := string(text); s {
	default:
		return fmt.Errorf("unknown schema type %q", s)
	case "", "AVRO":
		*t = TypeAvro
	case "PROTOBUF":
		*t = TypeProtobuf
	case "JSON":
		*t = TypeJSON
	}
	return nil
}

// UnmrshalText is the original, misspelled name of UnmarshalText.
//
// Deprecated: use UnmarshalText instead.
func (t *SchemaType) UnmrshalText(text []byte) error { return t.UnmarshalText(text) }

// CompatibilityLevel as an enum representing config compatibility levels.
type CompatibilityLevel int

const (
	// The levels start at 1 so that the zero value is not a valid level.
	CompatNone CompatibilityLevel = 1 + iota
	CompatBackward
	CompatBackwardTransitive
	CompatForward
	CompatForwardTransitive
	CompatFull
	CompatFullTransitive
)

// String returns the registry's name for the compatibility level, or the
// empty string for an unknown value.
func (l CompatibilityLevel) String() string {
	switch l {
	case CompatNone:
		return "NONE"
	case CompatBackward:
		return "BACKWARD"
	case CompatBackwardTransitive:
		return "BACKWARD_TRANSITIVE"
	case CompatForward:
		return "FORWARD"
	case CompatForwardTransitive:
		return "FORWARD_TRANSITIVE"
	case CompatFull:
		return "FULL"
	case CompatFullTransitive:
		return "FULL_TRANSITIVE"
	default:
		return ""
	}
}

// MarshalText implements encoding.TextMarshaler.
func (l CompatibilityLevel) MarshalText() ([]byte, error) {
	s := l.String()
	if s == "" {
		return nil, fmt.Errorf("unknown compatibility level %d", l)
	}
	return []byte(s), nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
func (l *CompatibilityLevel) UnmarshalText(text []byte) error {
	switch s := string(text); s {
	default:
		return fmt.Errorf("unknown compatibility level %q", s)
	case "NONE":
		*l = CompatNone
	case "BACKWARD":
		*l = CompatBackward
	case "BACKWARD_TRANSITIVE":
		*l = CompatBackwardTransitive
	case "FORWARD":
		*l = CompatForward
	case "FORWARD_TRANSITIVE":
		*l = CompatForwardTransitive
	case "FULL":
		*l = CompatFull
	case "FULL_TRANSITIVE":
		*l = CompatFullTransitive
	}
	return nil
}

// UnmrshalText is the original, misspelled name of UnmarshalText.
//
// Deprecated: use UnmarshalText instead.
func (l *CompatibilityLevel) UnmrshalText(text []byte) error { return l.UnmarshalText(text) }

// Mode as an enum representing the "mode" of the registry or a subject.
type Mode int

const (
	ModeImport Mode = iota
	ModeReadOnly
	ModeReadWrite
)

// String returns the registry's name for the mode, or the empty string for
// an unknown value.
func (m Mode) String() string {
	switch m {
	case ModeImport:
		return "IMPORT"
	case ModeReadOnly:
		return "READONLY"
	case ModeReadWrite:
		return "READWRITE"
	default:
		return ""
	}
}

// MarshalText implements encoding.TextMarshaler.
func (m Mode) MarshalText() ([]byte, error) {
	s := m.String()
	if s == "" {
		return nil, fmt.Errorf("unknown mode %d", m)
	}
	return []byte(s), nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
func (m *Mode) UnmarshalText(text []byte) error {
	switch s := string(text); s {
	default:
		return fmt.Errorf("unknown mode %q", s)
	case "IMPORT":
		*m = ModeImport
	case "READONLY":
		*m = ModeReadOnly
	case "READWRITE":
		*m = ModeReadWrite
	}
	return nil
}

// UnmrshalText is the original, misspelled name of UnmarshalText.
//
// Deprecated: use UnmarshalText instead.
func (m *Mode) UnmrshalText(text []byte) error { return m.UnmarshalText(text) }
var (
	// ErrNotRegistered is returned from Serde when attempting to encode a
	// value or decode an ID that has not been registered, or when using
	// Decode with a missing new value function.
	ErrNotRegistered = errors.New("registration is missing for encode/decode")

	// ErrBadHeader is returned from Decode when the input slice is shorter
	// than five bytes, or if the first byte is not the magic 0 byte.
	ErrBadHeader = errors.New("5 byte header for value is missing or does not have 0 magic byte")
)

type (
	// SerdeOpt is an option to configure a Serde.
	SerdeOpt interface{ apply(*tserde) }
	serdeOpt struct{ fn func(*tserde) }
)

func (o serdeOpt) apply(t *tserde) { o.fn(t) }

// EncodeFn allows Serde to encode a value.
func EncodeFn(fn func(interface{}) ([]byte, error)) SerdeOpt {
	return serdeOpt{func(t *tserde) { t.encode = fn }}
}

// AppendEncodeFn allows Serde to encode a value to an existing slice. This
// can be more efficient than EncodeFn; this function is used if it exists.
func AppendEncodeFn(fn func([]byte, interface{}) ([]byte, error)) SerdeOpt {
	return serdeOpt{func(t *tserde) { t.appendEncode = fn }}
}

// DecodeFn allows Serde to decode into a value.
func DecodeFn(fn func([]byte, interface{}) error) SerdeOpt {
	return serdeOpt{func(t *tserde) { t.decode = fn }}
}

// NewValueFn allows Serde to generate a new value to decode into, allowing the use
// of the Decode and MustDecode functions.
func NewValueFn(fn func() interface{}) SerdeOpt {
	return serdeOpt{func(t *tserde) { t.mk = fn }}
}

// tserde is one registration: a schema ID and the functions used to encode
// and decode values for it. Any of the functions may be nil.
type tserde struct {
	id           uint32
	encode       func(interface{}) ([]byte, error)
	appendEncode func([]byte, interface{}) ([]byte, error)
	decode       func([]byte, interface{}) error
	mk           func() interface{}
}

// Serde encodes and decodes values according to the schema registry wire
// format. A Serde itself does not perform schema auto-discovery and type
// auto-decoding. To aid in strong typing and validated encoding/decoding, you
// must register IDs and values to how to encode or decode them.
//
// To use a Serde for encoding, you must pre-register schema ids and values you
// will encode, and then you can use the encode functions.
//
// To use a Serde for decoding, you can either pre-register schema ids and
// values you will consume, or you can discover the schema every time you
// receive an ErrNotRegistered error from decode.
type Serde struct {
	ids   atomic.Value // map[int]tserde
	types atomic.Value // map[reflect.Type]tserde
	mu    sync.Mutex   // serializes Register; lookups only load the maps
}

var (
	noIDs   = make(map[int]tserde)
	noTypes = make(map[reflect.Type]tserde)
)

// loadIDs returns the current ID registrations, or an empty map if nothing
// has been registered. The returned map must not be modified.
func (s *Serde) loadIDs() map[int]tserde {
	ids := s.ids.Load()
	if ids == nil {
		return noIDs
	}
	return ids.(map[int]tserde)
}

// loadTypes returns the current type registrations, or an empty map if
// nothing has been registered. The returned map must not be modified.
func (s *Serde) loadTypes() map[reflect.Type]tserde {
	types := s.types.Load()
	if types == nil {
		return noTypes
	}
	return types.(map[reflect.Type]tserde)
}

// Register registers a schema ID and the value it corresponds to, as well as
// the encoding or decoding functions. You need to register functions depending
// on whether you are only encoding, only decoding, or both.
func (s *Serde) Register(id int, v interface{}, opts ...SerdeOpt) {
	t := tserde{id: uint32(id)}
	for _, opt := range opts {
		opt.apply(&t)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Both maps are replaced with modified copies rather than mutated, so
	// concurrent encodes and decodes can read them without locking.
	{
		prev := s.loadIDs()
		dup := make(map[int]tserde, len(prev)+1)
		for k, v := range prev {
			dup[k] = v
		}
		dup[id] = t
		s.ids.Store(dup)
	}

	{
		prev := s.loadTypes()
		dup := make(map[reflect.Type]tserde, len(prev)+1)
		for k, v := range prev {
			dup[k] = v
		}
		dup[reflect.TypeOf(v)] = t
		s.types.Store(dup)
	}
}

// Encode encodes a value according to the schema registry wire format and
// returns it. If EncodeFn was not used, this returns ErrNotRegistered.
func (s *Serde) Encode(v interface{}) ([]byte, error) {
	return s.AppendEncode(nil, v)
}

// AppendEncode appends an encoded value to b according to the schema registry
// wire format and returns it. If EncodeFn was not used, this returns
// ErrNotRegistered.
func (s *Serde) AppendEncode(b []byte, v interface{}) ([]byte, error) {
	t, ok := s.loadTypes()[reflect.TypeOf(v)]
	if !ok || (t.encode == nil && t.appendEncode == nil) {
		return b, ErrNotRegistered
	}

	// Wire header: magic 0 byte, then the schema ID as a big endian uint32.
	b = append(b,
		0,
		byte(t.id>>24),
		byte(t.id>>16),
		byte(t.id>>8),
		byte(t.id>>0),
	)

	if t.appendEncode != nil {
		return t.appendEncode(b, v)
	}
	encoded, err := t.encode(v)
	if err != nil {
		return nil, err
	}
	return append(b, encoded...), nil
}

// MustEncode returns the value of Encode, panicking on error.
func (s *Serde) MustEncode(v interface{}) []byte {
	b, err := s.Encode(v)
	if err != nil {
		panic(err)
	}
	return b
}

// MustAppendEncode returns the value of AppendEncode, panicking on error.
func (s *Serde) MustAppendEncode(b []byte, v interface{}) []byte {
	b, err := s.AppendEncode(b, v)
	if err != nil {
		panic(err)
	}
	return b
}

// decodeFind validates the five byte wire header of b and looks up the
// registration for the schema ID it carries. It returns the payload that
// follows the header and the registration, which always has a decode
// function.
func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
	if len(b) < 5 || b[0] != 0 {
		return nil, tserde{}, ErrBadHeader
	}
	id := binary.BigEndian.Uint32(b[1:5])

	t, ok := s.loadIDs()[int(id)]
	if !ok || t.decode == nil {
		return nil, tserde{}, ErrNotRegistered
	}
	return b[5:], t, nil
}

// Decode decodes b into a value returned by NewValueFn. If DecodeFn or
// NewValueFn options were not used, this returns ErrNotRegistered. The
// returned value is nil if the header is bad or nothing is registered.
//
// Serde does not handle references in schemas; it is up to you to register the
// full decode function for any top-level ID, regardless of how many other
// schemas are referenced in top-level ID.
func (s *Serde) Decode(b []byte) (interface{}, error) {
	payload, t, err := s.decodeFind(b)
	if err != nil {
		return nil, err
	}
	if t.mk == nil {
		return nil, ErrNotRegistered
	}
	v := t.mk()
	return v, t.decode(payload, v)
}

// DecodeInto decodes b into v. If DecodeFn option was not used, this returns
// ErrNotRegistered.
//
// Serde does not handle references in schemas; it is up to you to register the
// full decode function for any top-level ID, regardless of how many other
// schemas are referenced in top-level ID.
func (s *Serde) DecodeInto(b []byte, v interface{}) error {
	payload, t, err := s.decodeFind(b)
	if err != nil {
		return err
	}
	return t.decode(payload, v)
}