Skip to content

Commit

Permalink
V2: Clean up reflect types (#884)
Browse files Browse the repository at this point in the history
  • Loading branch information
timostamm authored Jun 13, 2024
1 parent 465aea3 commit fe29cca
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 87 deletions.
5 changes: 0 additions & 5 deletions packages/protobuf-test/src/reflect/reflect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,6 @@ describe("ReflectMessage", () => {
});
describe("returns error setting undefined", () => {
test.each(desc.fields)("for proto3 field $name", (f) => {
// @ts-expect-error ignore to test runtime behavior
const err = catchFieldError(() => r.set(f, undefined));
expect(err).toBeDefined();
expect(err?.message).toMatch(/^expected .*, got undefined/);
Expand All @@ -364,7 +363,6 @@ describe("ReflectMessage", () => {
});
describe("returns error setting null", () => {
test.each(desc.fields)("for proto3 field $name", (f) => {
// @ts-expect-error ignore to test runtime behavior
const err = catchFieldError(() => r.set(f, null));
expect(err).toBeDefined();
expect(err?.message).toMatch(/^expected .*, got null/);
Expand All @@ -373,15 +371,13 @@ describe("ReflectMessage", () => {
});
describe("throws error setting array", () => {
test.each(desc.fields)("$name", (f) => {
// @ts-expect-error ignore to test runtime behavior
const err = catchFieldError(() => r.set(f, [1, 2]));
expect(err?.message).toMatch(/^expected .*, got Array\(2\)$/);
expect(err?.name).toMatch("FieldValueInvalidError");
});
});
describe("throws error setting object", () => {
test.each(desc.fields)("$name", (f) => {
// @ts-expect-error ignore for test
const err = catchFieldError(() => r.set(f, new Date()));
expect(err?.message).toMatch(/^expected .*, got object$/);
expect(err?.name).toMatch("FieldValueInvalidError");
Expand All @@ -390,7 +386,6 @@ describe("ReflectMessage", () => {
describe("throws error setting message", () => {
test.each(desc.fields)("$name", (f) => {
const err = catchFieldError(() =>
// @ts-expect-error ignore to test runtime behavior
r.set(f, create(proto3_ts.Proto3MessageDesc)),
);
expect(err?.message).toMatch(
Expand Down
1 change: 0 additions & 1 deletion packages/protobuf/src/clone.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ function cloneReflect(i: ReflectMessage): ReflectMessage {
// eslint-disable-next-line no-case-declarations
const map = o.get(f);
for (const entry of i.get(f).entries()) {
// @ts-expect-error TODO fix type error
map.set(entry[0], cloneSingular(f, entry[1]));
}
break;
Expand Down
26 changes: 10 additions & 16 deletions packages/protobuf/src/equals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
// limitations under the License.

import type { MessageShape } from "./types.js";
import { scalarEquals } from "./reflect/scalar.js";
import { scalarEquals, type ScalarValue } from "./reflect/scalar.js";
import { reflect } from "./reflect/reflect.js";
import type { DescField, DescMessage } from "./descriptors.js";
import type { MapEntryKey, ReflectMessage } from "./reflect/index.js";
import type { ReflectMessage } from "./reflect/index.js";

/**
* Compare two messages of the same type.
Expand Down Expand Up @@ -71,7 +71,7 @@ function fieldEquals(
case "map": {
const ma = a.get(f);
const mb = b.get(f);
const keysA: MapEntryKey[] = [];
const keysA: unknown[] = [];
for (const k of ma.keys()) {
if (!mb.has(k)) {
return false;
Expand All @@ -93,16 +93,12 @@ function fieldEquals(
case "enum":
return false;
case "message":
// TODO fix type error
// @ts-expect-error TODO
if (!reflectEquals(va, vb)) {
if (!reflectEquals(va as ReflectMessage, vb as ReflectMessage)) {
return false;
}
break;
case "scalar":
// TODO fix type error
// @ts-expect-error TODO
if (!scalarEquals(f.scalar, va, vb)) {
if (!scalarEquals(f.scalar, va as ScalarValue, vb as ScalarValue)) {
return false;
}
break;
Expand All @@ -117,23 +113,21 @@ function fieldEquals(
return false;
}
for (let i = 0; i < la.size; i++) {
if (la.get(i) === lb.get(i)) {
const va = la.get(i);
const vb = lb.get(i);
if (va === vb) {
continue;
}
switch (f.listKind) {
case "enum":
return false;
case "message":
// TODO fix type error
// @ts-expect-error TODO
if (!reflectEquals(la.get(i), lb.get(i))) {
if (!reflectEquals(va as ReflectMessage, vb as ReflectMessage)) {
return false;
}
break;
case "scalar":
// TODO fix type error
// @ts-expect-error TODO
if (!scalarEquals(f.scalar, la.get(i), lb.get(i))) {
if (!scalarEquals(f.scalar, va as ScalarValue, vb as ScalarValue)) {
return false;
}
break;
Expand Down
5 changes: 2 additions & 3 deletions packages/protobuf/src/from-binary.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import { type DescField, type DescMessage, ScalarType } from "./descriptors.js";
import type { MessageShape } from "./types.js";
import type {
MapEntryKey,
ReflectList,
ReflectMap,
ReflectMessage,
Expand Down Expand Up @@ -175,14 +174,14 @@ function readMapEntry(
options: BinaryReadOptions,
): void {
const field = map.field();
let key: MapEntryKey | undefined,
let key: ScalarValue | undefined,
val: ScalarValue | ReflectMessage | undefined;
const end = reader.pos + reader.uint32();
while (reader.pos < end) {
const [fieldNo] = reader.tag();
switch (fieldNo) {
case 1:
key = readScalar(reader, field.mapKey) as MapEntryKey;
key = readScalar(reader, field.mapKey);
break;
case 2:
switch (field.mapKind) {
Expand Down
4 changes: 0 additions & 4 deletions packages/protobuf/src/from-json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,6 @@ function readMapField(map: ReflectMap, json: JsonValue, opts: JsonReadOptions) {
break;
}
const key = mapKeyFromJson(field.mapKey, jsonMapKey);
// TODO fix types
// @ts-expect-error TODO
map.set(key, value);
}
}
Expand Down Expand Up @@ -376,8 +374,6 @@ function readScalarField(
if (scalarValue === tokenNull) {
msg.clear(field);
} else {
// TODO fix type error
// @ts-expect-error TODO
msg.set(field, scalarValue);
}
}
Expand Down
36 changes: 4 additions & 32 deletions packages/protobuf/src/reflect/reflect-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ export interface ReflectMessage {
* Note that get() never returns `undefined`. To determine whether a field is
* set, use isSet().
*/
get<Field extends DescField>(field: Field): ReflectGetValue<Field>;
get<Field extends DescField>(field: Field): ReflectMessageGet<Field>;

/**
* Set a field value.
Expand All @@ -137,10 +137,7 @@ export interface ReflectMessage {
* Throws an error if the value is invalid for the field. `undefined` is not
* a valid value. To reset a field, use clear().
*/
set<Field extends DescField>(
field: Field,
value: ReflectSetValue<Field>,
): void;
set<Field extends DescField>(field: Field, value: unknown): void;

/**
* Returns the unknown fields of the message.
Expand Down Expand Up @@ -223,7 +220,7 @@ export interface ReflectList<V = unknown> extends Iterable<V> {
* ReflectMap converts keys to their closest possible type in TypeScript.
* - Messages are wrapped in a ReflectMessage.
*/
export interface ReflectMap<K extends MapEntryKey = MapEntryKey, V = unknown>
export interface ReflectMap<K = unknown, V = unknown>
extends ReadonlyMap<K, V> {
/**
* Returns the map field.
Expand All @@ -250,36 +247,11 @@ export interface ReflectMap<K extends MapEntryKey = MapEntryKey, V = unknown>
[unsafeLocal]: Record<string, unknown>;
}

/**
* A ReflectMap key.
*/
export type MapEntryKey = string | number | bigint | boolean;

type enumVal = number;

/**
* The return type of ReflectMessage.get()
*/
// prettier-ignore
export type ReflectGetValue<Field extends DescField = DescField> = (
Field extends { fieldKind: "map" } ? (
Field extends { mapKind: "message" } ? ReflectMap<MapEntryKey, ReflectMessage> :
Field extends { mapKind: "enum" } ? ReflectMap<MapEntryKey, enumVal> :
Field extends { mapKind: "scalar"; scalar: infer T } ? ReflectMap<MapEntryKey, ScalarValue<T>> :
never
) :
Field extends { fieldKind: "list" } ? ReflectList :
Field extends { fieldKind: "enum" } ? number :
Field extends { fieldKind: "message" } ? ReflectMessage :
Field extends { fieldKind: "scalar"; scalar: infer T } ? ScalarValue<T> :
never
);

/**
* The type of the "value" argument of ReflectMessage.set()
*/
// prettier-ignore
export type ReflectSetValue<Field extends DescField = DescField> = (
export type ReflectMessageGet<Field extends DescField = DescField> = (
Field extends { fieldKind: "map" } ? ReflectMap :
Field extends { fieldKind: "list" } ? ReflectList :
Field extends { fieldKind: "enum" } ? number :
Expand Down
32 changes: 14 additions & 18 deletions packages/protobuf/src/reflect/reflect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ import type { Message, MessageShape, UnknownField } from "../types.js";
import { checkField, checkListItem, checkMapEntry } from "./reflect-check.js";
import { FieldError } from "./error.js";
import type {
MapEntryKey,
ReflectGetValue,
ReflectMessageGet,
ReflectList,
ReflectMap,
ReflectMessage,
ReflectSetValue,
} from "./reflect-types.js";
import {
unsafeClear,
Expand Down Expand Up @@ -58,10 +56,10 @@ export function reflect<Desc extends DescMessage>(
*/
check = true,
): ReflectMessage {
return new ReflectMessageImpl<Desc>(messageDesc, message, check);
return new ReflectMessageImpl(messageDesc, message, check);
}

class ReflectMessageImpl<Desc extends DescMessage> implements ReflectMessage {
class ReflectMessageImpl implements ReflectMessage {
readonly desc: DescMessage;
readonly fields: readonly DescField[];
get sortedFields() {
Expand All @@ -82,7 +80,7 @@ class ReflectMessageImpl<Desc extends DescMessage> implements ReflectMessage {
private lists = new Map<DescField, ReflectList>();
private maps = new Map<DescField, ReflectMap>();

constructor(messageDesc: Desc, message?: MessageShape<Desc>, check = true) {
constructor(messageDesc: DescMessage, message?: Message, check = true) {
this.check = check;
this.desc = messageDesc;
this.message = this[unsafeLocal] = message ?? create(messageDesc);
Expand Down Expand Up @@ -115,7 +113,7 @@ class ReflectMessageImpl<Desc extends DescMessage> implements ReflectMessage {
unsafeClear(this.message, field);
}

get<Field extends DescField>(field: Field): ReflectGetValue<Field> {
get<Field extends DescField>(field: Field): ReflectMessageGet<Field> {
assertOwn(this.message, field);
let value = unsafeGet(this.message, field);
switch (field.fieldKind) {
Expand All @@ -128,7 +126,7 @@ class ReflectMessageImpl<Desc extends DescMessage> implements ReflectMessage {
(list = new ReflectListImpl(field, value as unknown[], this.check)),
);
}
return list as ReflectGetValue<Field>;
return list as ReflectMessageGet<Field>;
case "map":
// eslint-disable-next-line no-case-declarations
let map = this.maps.get(field);
Expand All @@ -142,7 +140,7 @@ class ReflectMessageImpl<Desc extends DescMessage> implements ReflectMessage {
)),
);
}
return map as ReflectGetValue<Field>;
return map as ReflectMessageGet<Field>;
case "message":
if (
value !== undefined &&
Expand All @@ -158,22 +156,20 @@ class ReflectMessageImpl<Desc extends DescMessage> implements ReflectMessage {
field.message,
value as Message | undefined,
this.check,
) as ReflectGetValue<Field>;
) as ReflectMessageGet<Field>;
case "scalar":
return (
value === undefined
? scalarZeroValue(field.scalar, false)
: longToReflect(field, value)
) as ReflectGetValue<Field>;
) as ReflectMessageGet<Field>;
case "enum":
return (value ?? field.enum.values[0].number) as ReflectGetValue<Field>;
return (value ??
field.enum.values[0].number) as ReflectMessageGet<Field>;
}
}

set<Field extends DescField>(
field: Field,
value: ReflectSetValue<Field>,
): void {
set<Field extends DescField>(field: Field, value: unknown): void {
assertOwn(this.message, field);
if (this.check) {
const err = checkField(field, value);
Expand Down Expand Up @@ -307,7 +303,7 @@ class ReflectListImpl<V> implements ReflectList<V> {
/**
* Create a ReflectMap.
*/
export function reflectMap<K extends MapEntryKey, V>(
export function reflectMap<K = unknown, V = unknown>(
field: DescField & { fieldKind: "map" },
unsafeInput?: Record<string, unknown>,
/**
Expand All @@ -322,7 +318,7 @@ export function reflectMap<K extends MapEntryKey, V>(
return new ReflectMapImpl(field, unsafeInput, check);
}

class ReflectMapImpl<K extends MapEntryKey, V> implements ReflectMap<K, V> {
class ReflectMapImpl<K, V> implements ReflectMap<K, V> {
private readonly check: boolean;
private readonly _field: DescField & { fieldKind: "map" };
[unsafeLocal]: Record<string, unknown>;
Expand Down
4 changes: 2 additions & 2 deletions packages/protobuf/src/reflect/scalar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ export type ScalarValue<
*/
export function scalarEquals(
type: ScalarType,
a: string | boolean | number | bigint | Uint8Array | undefined,
b: string | boolean | number | bigint | Uint8Array | undefined,
a: ScalarValue | undefined,
b: ScalarValue | undefined,
): boolean {
if (a === b) {
// This correctly matches equal values except BYTES and (possibly) 64-bit integers.
Expand Down
2 changes: 1 addition & 1 deletion packages/protobuf/src/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ function createMutableRegistry(): MutableRegistry {
case "file":
files.set(desc.proto.name, desc);
break;
// @ts-expect-error TS7029
// @ts-expect-error TS7029: Fallthrough case in switch
case "extension":
// eslint-disable-next-line no-case-declarations
let numberToExt = extendees.get(desc.extendee.typeName);
Expand Down
8 changes: 3 additions & 5 deletions packages/protobuf/src/to-json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,20 @@ function mapToJson(map: ReflectMap, opts: JsonWriteOptions) {
switch (f.mapKind) {
case "scalar":
for (const [entryKey, entryValue] of map) {
jsonObj[entryKey.toString()] = scalarToJson(f, entryValue); // JSON standard allows only (double quoted) string as property key
jsonObj[entryKey as keyof object] = scalarToJson(f, entryValue);
}
break;
case "message":
for (const [entryKey, entryValue] of map) {
// JSON standard allows only (double quoted) string as property key
jsonObj[entryKey.toString()] = reflectToJson(
jsonObj[entryKey as keyof object] = reflectToJson(
entryValue as ReflectMessage,
opts,
);
}
break;
case "enum":
for (const [entryKey, entryValue] of map) {
// JSON standard allows only (double quoted) string as property key
jsonObj[entryKey.toString()] = enumToJson(
jsonObj[entryKey as keyof object] = enumToJson(
f.enum,
entryValue,
opts.enumAsInteger,
Expand Down

0 comments on commit fe29cca

Please sign in to comment.