Skip to content

Commit

Permalink
fix(typings): improve inference of the entity type
Browse files Browse the repository at this point in the history
Due to how `EntityData` was defined, the entity type was sometimes inferred incorrectly to
the type of given `data` variable instead of using the first `entityName` parameter.

Closes #876
  • Loading branch information
B4nan committed Sep 26, 2020
1 parent f459334 commit 67f8015
Show file tree
Hide file tree
Showing 15 changed files with 37 additions and 39 deletions.
2 changes: 1 addition & 1 deletion packages/core/src/EntityManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ export class EntityManager<D extends IDatabaseDriver = IDatabaseDriver> {
}

entity = this.getEntityFactory().create<T>(entityName, data as EntityData<T>, { refresh: options.refresh, merge: true, convertCustomTypes: true });
this.getUnitOfWork().registerManaged(entity, data, options.refresh);
this.getUnitOfWork().registerManaged(entity, data as EntityData<T>, options.refresh);
await this.lockAndPopulate(entityName, entity, where, options);

return entity as Loaded<T, P>;
Expand Down
8 changes: 4 additions & 4 deletions packages/core/src/drivers/DatabaseDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ export abstract class DatabaseDriver<C extends Connection> implements IDatabaseD
protected constructor(protected readonly config: Configuration,
protected readonly dependencies: string[]) { }

abstract async find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOptions<T>, ctx?: Transaction): Promise<T[]>;
abstract async find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOptions<T>, ctx?: Transaction): Promise<EntityData<T>[]>;

abstract async findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOneOptions<T>, ctx?: Transaction): Promise<T | null>;
abstract async findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOneOptions<T>, ctx?: Transaction): Promise<EntityData<T> | null>;

abstract async nativeInsert<T extends AnyEntity<T>>(entityName: string, data: EntityData<T>, ctx?: Transaction): Promise<QueryResult>;

Expand Down Expand Up @@ -55,7 +55,7 @@ export abstract class DatabaseDriver<C extends Connection> implements IDatabaseD
await this.nativeUpdate<T>(coll.owner.constructor.name, coll.owner.__helper!.__primaryKey, data, ctx);
}

mapResult<T extends AnyEntity<T>>(result: EntityData<T>, meta: EntityMetadata, populate: PopulateOptions<T>[] = []): T | null {
mapResult<T extends AnyEntity<T>>(result: EntityData<T>, meta: EntityMetadata, populate: PopulateOptions<T>[] = []): EntityData<T> | null {
if (!result || !meta) {
return null;
}
Expand All @@ -80,7 +80,7 @@ export abstract class DatabaseDriver<C extends Connection> implements IDatabaseD
}
});

return ret as T;
return ret;
}

async connect(): Promise<C> {
Expand Down
6 changes: 3 additions & 3 deletions packages/core/src/drivers/IDatabaseDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ export interface IDatabaseDriver<C extends Connection = Connection> {
/**
* Finds selection of entities
*/
find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOptions<T>, ctx?: Transaction): Promise<T[]>;
find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOptions<T>, ctx?: Transaction): Promise<EntityData<T>[]>;

/**
* Finds single entity (table row, document)
*/
findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOneOptions<T>, ctx?: Transaction): Promise<T | null>;
findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOneOptions<T>, ctx?: Transaction): Promise<EntityData<T> | null>;

nativeInsert<T extends AnyEntity<T>>(entityName: string, data: EntityData<T>, ctx?: Transaction): Promise<QueryResult>;

Expand All @@ -47,7 +47,7 @@ export interface IDatabaseDriver<C extends Connection = Connection> {

aggregate(entityName: string, pipeline: any[]): Promise<any[]>;

mapResult<T extends AnyEntity<T>>(result: EntityData<T>, meta: EntityMetadata, populate?: PopulateOptions<T>[]): T | null;
mapResult<T extends AnyEntity<T>>(result: EntityData<T>, meta: EntityMetadata, populate?: PopulateOptions<T>[]): EntityData<T> | null;

/**
* When driver uses pivot tables for M:N, this method will load identifiers for given collections from them
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/entity/EntityAssigner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ export class EntityAssigner {
return em.getReference(prop.type, item);
}

if (Utils.isObject<T>(item) && options.merge) {
if (Utils.isObject<EntityData<T>>(item) && options.merge) {
return em.merge<T>(prop.type, item);
}

if (Utils.isObject<T>(item)) {
if (Utils.isObject<EntityData<T>>(item)) {
return em.create<T>(prop.type, item);
}

Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/entity/EntityFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export class EntityFactory {
constructor(private readonly unitOfWork: UnitOfWork,
private readonly em: EntityManager) { }

create<T extends AnyEntity<T>, P extends Populate<T> = keyof T>(entityName: EntityName<T>, data: EntityData<T>, options: FactoryOptions = {}): New<T, P> {
create<T extends AnyEntity<T>, P extends Populate<T> = any>(entityName: EntityName<T>, data: EntityData<T>, options: FactoryOptions = {}): New<T, P> {
options.initialized = options.initialized ?? true;

if (Utils.isEntity<T>(data)) {
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/entity/EntityLoader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ export class EntityLoader {
}

private async findChildrenFromPivotTable<T extends AnyEntity<T>>(filtered: T[], prop: EntityProperty, field: keyof T, refresh: boolean, where?: FilterQuery<T>, orderBy?: QueryOrderMap): Promise<AnyEntity[]> {
const ids = filtered.map(e => e.__helper!.__primaryKeys);
const ids = filtered.map((e: AnyEntity<T>) => e.__helper!.__primaryKeys);

if (prop.customType) {
ids.forEach((id, idx) => ids[idx] = QueryHelper.processCustomType(prop, id, this.driver.getPlatform()));
Expand Down
6 changes: 3 additions & 3 deletions packages/core/src/entity/EntityTransformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ export class EntityTransformer {
return visible && !meta.primaryKeys.includes(prop) && !prop.startsWith('_') && !ignoreFields.includes(prop);
}

private static propertyName<T extends AnyEntity<T>>(meta: EntityMetadata<T>, prop: keyof T & string, platform?: Platform): string {
private static propertyName<T extends AnyEntity<T>>(meta: EntityMetadata<T>, prop: keyof T & string, platform?: Platform): keyof T & string {
if (meta.properties[prop].serializedName) {
return meta.properties[prop].serializedName!;
return meta.properties[prop].serializedName as keyof T & string;
}

if (meta.properties[prop].primary && platform) {
return platform.getSerializedPrimaryKeyField(prop);
return platform.getSerializedPrimaryKeyField(prop) as keyof T & string;
}

return prop;
Expand Down
8 changes: 3 additions & 5 deletions packages/core/src/typings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export interface IWrappedEntity<T extends AnyEntity<T>, PK extends keyof T, P =
assign(data: any, options?: AssignOptions | boolean): T;
}

export interface IWrappedEntityInternal<T extends AnyEntity<T>, PK extends keyof T, P = keyof T> extends IWrappedEntity<T, PK, P> {
export interface IWrappedEntityInternal<T, PK extends keyof T, P = keyof T> extends IWrappedEntity<T, PK, P> {
hasPrimaryKey(): boolean;
__meta: EntityMetadata<T>;
__data: Dictionary;
Expand All @@ -95,7 +95,7 @@ export interface IWrappedEntityInternal<T extends AnyEntity<T>, PK extends keyof
__serializedPrimaryKey: string & keyof T;
}

export type AnyEntity<T = any> = { [K in keyof T]?: T[K] } & {
export type AnyEntity<T = any> = Partial<T> & {
[PrimaryKeyType]?: unknown;
[EntityRepositoryType]?: unknown;
__helper?: IWrappedEntityInternal<T, keyof T>;
Expand All @@ -107,9 +107,7 @@ export type AnyEntity<T = any> = { [K in keyof T]?: T[K] } & {
export type EntityClass<T extends AnyEntity<T>> = Function & { prototype: T };
export type EntityClassGroup<T extends AnyEntity<T>> = { entity: EntityClass<T>; schema: EntityMetadata<T> | EntitySchema<T> };
export type EntityName<T extends AnyEntity<T>> = string | EntityClass<T> | EntitySchema<T, any>;
export type EntityDataProp<T> = T extends Scalar ? ExpandScalar<T> : (T | EntityData<T> | Primary<T>);
export type CollectionItem<T> = T extends Collection<any> | undefined ? EntityDataProp<ExpandProperty<T>>[] : EntityDataProp<T>;
export type EntityData<T> = T | { [K in keyof T | NonFunctionPropertyNames<T>]?: CollectionItem<T[K]> } & Dictionary;
export type EntityData<T> = { [P in keyof T]?: T[P] | any } & Dictionary;
export type GetRepository<T extends AnyEntity<T>, U> = T[typeof EntityRepositoryType] extends EntityRepository<any> | undefined ? NonNullable<T[typeof EntityRepositoryType]> : U;

export interface EntityProperty<T extends AnyEntity<T> = any> {
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/unit-of-work/ChangeSetComputer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export class ChangeSetComputer {
const data = this.comparator.prepareEntity(entity);

if (entity.__helper!.__originalEntityData) {
return Utils.diff(entity.__helper!.__originalEntityData, data);
return Utils.diff(entity.__helper!.__originalEntityData, data) as EntityData<T>;
}

return data;
Expand Down
8 changes: 4 additions & 4 deletions packages/core/src/unit-of-work/UnitOfWork.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export class UnitOfWork {
private readonly orphanRemoveStack = new Set<AnyEntity>();
private readonly changeSets = new Map<AnyEntity, ChangeSet<AnyEntity>>();
private readonly collectionUpdates = new Set<Collection<AnyEntity>>();
private readonly extraUpdates = new Set<[AnyEntity, string, AnyEntity | Reference<AnyEntity>]>();
private readonly extraUpdates = new Set<[AnyEntity, string, AnyEntity | Reference<any> | Collection<any>]>();
private readonly metadata = this.em.getMetadata();
private readonly platform = this.em.getDriver().getPlatform();
private readonly eventManager = this.em.getEventManager();
Expand Down Expand Up @@ -115,10 +115,10 @@ export class UnitOfWork {
/**
* Returns stored snapshot of entity state that is used for change set computation.
*/
getOriginalEntityData<T extends AnyEntity<T>>(entity?: T): EntityData<T>[] | EntityData<T> | undefined {
getOriginalEntityData<T extends AnyEntity<T>>(entity?: T): EntityData<AnyEntity>[] | EntityData<T> | undefined {
if (!entity) {
return [...this.identityMap.values()].map(e => {
return e.__helper!.__originalEntityData;
return e.__helper!.__originalEntityData!;
});
}

Expand All @@ -141,7 +141,7 @@ export class UnitOfWork {
return [...this.collectionUpdates];
}

getExtraUpdates(): Set<[AnyEntity, string, (AnyEntity | Reference<AnyEntity>)]> {
getExtraUpdates(): Set<[AnyEntity, string, (AnyEntity | Reference<any> | Collection<any>)]> {
return this.extraUpdates;
}

Expand Down
Empty file.
4 changes: 2 additions & 2 deletions packages/core/src/utils/EntityComparator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ export class EntityComparator {
*/
prepareEntity<T extends AnyEntity<T>>(entity: T): EntityData<T> {
if ((entity as Dictionary).__prepared) {
return entity;
return entity as EntityData<T>;
}

const meta = this.metadata.get<T>(entity.constructor.name);
const ret = {} as EntityData<T>;

if (meta.discriminatorValue) {
ret[meta.root.discriminatorColumn as keyof T] = meta.discriminatorValue as unknown as T[keyof T];
ret[meta.root.discriminatorColumn as keyof T] = meta.discriminatorValue as unknown as EntityData<T>[keyof T];
}

// copy all comparable props, ignore collections and references, process custom types
Expand Down
12 changes: 6 additions & 6 deletions packages/knex/src/AbstractSqlDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra
return new SqlEntityManager(this.config, this, this.metadata, useContext) as unknown as EntityManager<D>;
}

async find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options: FindOptions<T> = {}, ctx?: Transaction<KnexTransaction>): Promise<T[]> {
async find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options: FindOptions<T> = {}, ctx?: Transaction<KnexTransaction>): Promise<EntityData<T>[]> {
options = { populate: [], orderBy: {}, ...options };
const meta = this.metadata.find<T>(entityName)!;
const populate = this.autoJoinOneToOneOwner(meta, options.populate as PopulateOptions<T>[]);
Expand Down Expand Up @@ -67,7 +67,7 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra
return result;
}

async findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOneOptions<T>, ctx?: Transaction<KnexTransaction>): Promise<T | null> {
async findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options?: FindOneOptions<T>, ctx?: Transaction<KnexTransaction>): Promise<EntityData<T> | null> {
const opts = { populate: [], ...(options || {}) } as FindOptions<T>;
const meta = this.metadata.find(entityName)!;
const populate = this.autoJoinOneToOneOwner(meta, opts.populate as PopulateOptions<T>[]);
Expand All @@ -82,7 +82,7 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra
return res[0] || null;
}

mapResult<T extends AnyEntity<T>>(result: EntityData<T>, meta: EntityMetadata<T>, populate: PopulateOptions<T>[] = [], qb?: QueryBuilder<T>, map: Dictionary = {}): T | null {
mapResult<T extends AnyEntity<T>>(result: EntityData<T>, meta: EntityMetadata<T>, populate: PopulateOptions<T>[] = [], qb?: QueryBuilder<T>, map: Dictionary = {}): EntityData<T> | null {
const ret = super.mapResult(result, meta);

if (!ret) {
Expand All @@ -93,7 +93,7 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra
this.mapJoinedProps<T>(ret, meta, populate, qb, ret, map);
}

return ret as T;
return ret;
}

private mapJoinedProps<T extends AnyEntity<T>>(result: EntityData<T>, meta: EntityMetadata<T>, populate: PopulateOptions<T>[], qb: QueryBuilder<T>, root: EntityData<T>, map: Dictionary, parentJoinPath?: string) {
Expand Down Expand Up @@ -351,7 +351,7 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra
});
}

protected mergeJoinedResult<T extends AnyEntity<T>>(rawResults: Dictionary[], meta: EntityMetadata<T>): T[] {
protected mergeJoinedResult<T extends AnyEntity<T>>(rawResults: Dictionary[], meta: EntityMetadata<T>): EntityData<T>[] {
// group by the root entity primary key first
const res = rawResults.reduce((result, item) => {
const pk = Utils.getCompositeKeyHash<T>(item as T, meta);
Expand All @@ -361,7 +361,7 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra
return result;
}, {}) as Dictionary<any[]>;

return Object.values(res).map((rows: Dictionary[]) => rows[0]) as T[];
return Object.values(res).map((rows: Dictionary[]) => rows[0]) as EntityData<T>[];
}

protected getFieldsForJoinedLoad<T extends AnyEntity<T>>(qb: QueryBuilder<T>, meta: EntityMetadata<T>, populate: PopulateOptions<T>[] = [], parentTableAlias?: string, parentJoinPath?: string): Field<T>[] {
Expand Down
6 changes: 3 additions & 3 deletions packages/mongodb/src/MongoConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
import { inspect } from 'util';
import {
Connection, ConnectionConfig, QueryResult, Transaction, Utils, QueryOrder, QueryOrderMap,
FilterQuery, AnyEntity, EntityName, Dictionary,
FilterQuery, AnyEntity, EntityName, Dictionary, EntityData,
} from '@mikro-orm/core';

export class MongoConnection extends Connection {
Expand Down Expand Up @@ -76,7 +76,7 @@ export class MongoConnection extends Connection {
throw new Error(`${this.constructor.name} does not support generic execute method`);
}

async find<T extends AnyEntity<T>>(collection: string, where: FilterQuery<T>, orderBy?: QueryOrderMap, limit?: number, offset?: number, fields?: string[], ctx?: Transaction<ClientSession>): Promise<T[]> {
async find<T extends AnyEntity<T>>(collection: string, where: FilterQuery<T>, orderBy?: QueryOrderMap, limit?: number, offset?: number, fields?: string[], ctx?: Transaction<ClientSession>): Promise<EntityData<T>[]> {
collection = this.getCollectionName(collection);
const options: Dictionary = { session: ctx };

Expand Down Expand Up @@ -110,7 +110,7 @@ export class MongoConnection extends Connection {
const res = await resultSet.toArray();
this.logQuery(`${query}.toArray();`, Date.now() - now);

return res;
return res as EntityData<T>[];
}

async insertOne<T extends { _id: any }>(collection: string, data: Partial<T>, ctx?: Transaction<ClientSession>): Promise<QueryResult> {
Expand Down
6 changes: 3 additions & 3 deletions packages/mongodb/src/MongoDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ export class MongoDriver extends DatabaseDriver<MongoConnection> {
return new MongoEntityManager(this.config, this, this.metadata, useContext) as unknown as EntityManager<D>;
}

async find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options: FindOptions<T> = {}, ctx?: Transaction<ClientSession>): Promise<T[]> {
async find<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options: FindOptions<T> = {}, ctx?: Transaction<ClientSession>): Promise<EntityData<T>[]> {
const fields = this.buildFields(entityName, options.populate as PopulateOptions<T>[] || [], options.fields);
where = this.renameFields(entityName, where);
const res = await this.rethrow(this.getConnection('read').find<T>(entityName, where, options.orderBy, options.limit, options.offset, fields, ctx));

return res.map((r: T) => this.mapResult<T>(r, this.metadata.find(entityName)!)!);
return res.map(r => this.mapResult<T>(r, this.metadata.find(entityName)!)!);
}

async findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options: FindOneOptions<T> = { populate: [], orderBy: {} }, ctx?: Transaction<ClientSession>): Promise<T | null> {
async findOne<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options: FindOneOptions<T> = { populate: [], orderBy: {} }, ctx?: Transaction<ClientSession>): Promise<EntityData<T> | null> {
if (Utils.isPrimaryKey(where)) {
where = this.buildFilterById(entityName, where as string);
}
Expand Down

0 comments on commit 67f8015

Please sign in to comment.