Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/src/repositories/ocr.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ export class OcrRepository {
textScore: DummyValue.NUMBER,
},
],
DummyValue.STRING,
],
})
upsert(assetId: string, ocrDataList: Insertable<AssetOcrTable>[]) {
upsert(assetId: string, ocrDataList: Insertable<AssetOcrTable>[], searchText: string) {
let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId));
if (ocrDataList.length > 0) {
const searchText = ocrDataList.map((item) => item.text.trim()).join(' ');
(query as any) = query
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
.with('inserted_search', (db) =>
Expand Down
24 changes: 24 additions & 0 deletions server/src/schema/migrations/1764483051488-OCRBigramsForCJK.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { Kysely, sql } from 'kysely';
import { tokenizeForSearch } from 'src/utils/database';

export async function up(db: Kysely<any>): Promise<void> {
await sql`truncate ${sql.table('ocr_search')}`.execute(db);
const batch = [];
for await (const { assetId, text } of db
.selectFrom('asset_ocr')
.select(['assetId', sql<string>`string_agg(text, ' ')`.as('text')])
.groupBy('assetId')
.stream()) {
batch.push({ assetId, text: tokenizeForSearch(text) });
if (batch.length >= 5000) {
await db.insertInto('ocr_search').values(batch).execute();
batch.length = 0;
}
}

if (batch.length > 0) {
await db.insertInto('ocr_search').values(batch).execute();
}
}

export async function down(): Promise<void> {}
167 changes: 127 additions & 40 deletions server/src/services/ocr.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,21 @@ describe(OcrService.name, () => {
({ sut, mocks } = newTestService(OcrService));

mocks.config.getWorker.mockReturnValue(ImmichWorker.Microservices);
mocks.assetJob.getForOcr.mockResolvedValue({
visibility: AssetVisibility.Timeline,
previewFile: assetStub.image.files[1].path,
});
});

const mockOcrResult = (...texts: string[]) => {
mocks.machineLearning.ocr.mockResolvedValue({
box: texts.flatMap((_, i) => Array.from({ length: 8 }, (_, j) => i * 10 + j)),
boxScore: texts.map(() => 0.9),
text: texts,
textScore: texts.map(() => 0.95),
});
};

it('should work', () => {
expect(sut).toBeDefined();
});
Expand Down Expand Up @@ -72,10 +85,6 @@ describe(OcrService.name, () => {
text: ['One Two Three', 'Four Five'],
textScore: [0.95, 0.85],
});
mocks.assetJob.getForOcr.mockResolvedValue({
visibility: AssetVisibility.Timeline,
previewFile: assetStub.image.files[1].path,
});

expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);

Expand All @@ -88,36 +97,40 @@ describe(OcrService.name, () => {
maxResolution: 736,
}),
);
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, [
{
assetId: assetStub.image.id,
boxScore: 0.9,
text: 'One Two Three',
textScore: 0.95,
x1: 10,
y1: 20,
x2: 30,
y2: 40,
x3: 50,
y3: 60,
x4: 70,
y4: 80,
},
{
assetId: assetStub.image.id,
boxScore: 0.8,
text: 'Four Five',
textScore: 0.85,
x1: 90,
y1: 100,
x2: 110,
y2: 120,
x3: 130,
y3: 140,
x4: 150,
y4: 160,
},
]);
expect(mocks.ocr.upsert).toHaveBeenCalledWith(
assetStub.image.id,
[
{
assetId: assetStub.image.id,
boxScore: 0.9,
text: 'One Two Three',
textScore: 0.95,
x1: 10,
y1: 20,
x2: 30,
y2: 40,
x3: 50,
y3: 60,
x4: 70,
y4: 80,
},
{
assetId: assetStub.image.id,
boxScore: 0.8,
text: 'Four Five',
textScore: 0.85,
x1: 90,
y1: 100,
x2: 110,
y2: 120,
x3: 130,
y3: 140,
x4: 150,
y4: 160,
},
],
'One Two Three Four Five',
);
});

it('should apply config settings', async () => {
Expand All @@ -133,11 +146,7 @@ describe(OcrService.name, () => {
},
},
});
mocks.machineLearning.ocr.mockResolvedValue({ box: [], boxScore: [], text: [], textScore: [] });
mocks.assetJob.getForOcr.mockResolvedValue({
visibility: AssetVisibility.Timeline,
previewFile: assetStub.image.files[1].path,
});
mockOcrResult();

expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);

Expand All @@ -150,7 +159,7 @@ describe(OcrService.name, () => {
maxResolution: 1500,
}),
);
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, []);
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, [], '');
});

it('should skip invisible assets', async () => {
Expand All @@ -173,5 +182,83 @@ describe(OcrService.name, () => {
expect(mocks.machineLearning.ocr).not.toHaveBeenCalled();
expect(mocks.ocr.upsert).not.toHaveBeenCalled();
});

describe('search tokenization', () => {
it('should generate bigrams for Chinese text', async () => {
mockOcrResult('機器學習');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), '機器 器學 學習');
});

it('should generate bigrams for Japanese text', async () => {
mockOcrResult('テスト');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), 'テス スト');
});

it('should generate bigrams for Korean text', async () => {
mockOcrResult('한국어');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), '한국 국어');
});

it('should pass through Latin text unchanged', async () => {
mockOcrResult('Hello World');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), 'Hello World');
});

it('should handle mixed CJK and Latin text', async () => {
mockOcrResult('機器學習Model');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), '機器 器學 學習 Model');
});

it('should handle year followed by CJK', async () => {
mockOcrResult('2024年レポート');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(
assetStub.image.id,
expect.any(Array),
'2024 年レ レポ ポー ート',
);
});

it('should join multiple OCR boxes', async () => {
mockOcrResult('機器', 'Learning');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), '機器 Learning');
});

it('should normalize whitespace', async () => {
mockOcrResult(' Hello World ');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), 'Hello World');
});

it('should keep single CJK characters', async () => {
mockOcrResult('A', '中', 'B');

await sut.handleOcr({ id: assetStub.image.id });

expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), 'A 中 B');
});
});
});
});
13 changes: 9 additions & 4 deletions server/src/services/ocr.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum';
import { OCR } from 'src/repositories/machine-learning.repository';
import { BaseService } from 'src/services/base.service';
import { JobItem, JobOf } from 'src/types';
import { tokenizeForSearch } from 'src/utils/database';
import { isOcrEnabled } from 'src/utils/misc';

@Injectable()
Expand Down Expand Up @@ -53,8 +54,8 @@ export class OcrService extends BaseService {
}

const ocrResults = await this.machineLearningRepository.ocr(asset.previewFile, machineLearning.ocr);

await this.ocrRepository.upsert(id, this.parseOcrResults(id, ocrResults));
const { ocrDataList, searchText } = this.parseOcrResults(id, ocrResults);
await this.ocrRepository.upsert(id, ocrDataList, searchText);

await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() });

Expand All @@ -64,7 +65,9 @@ export class OcrService extends BaseService {

private parseOcrResults(id: string, { box, boxScore, text, textScore }: OCR) {
const ocrDataList = [];
const searchTokens = [];
for (let i = 0; i < text.length; i++) {
const rawText = text[i];
const boxOffset = i * 8;
ocrDataList.push({
assetId: id,
Expand All @@ -78,9 +81,11 @@ export class OcrService extends BaseService {
y4: box[boxOffset + 7],
boxScore: boxScore[i],
textScore: textScore[i],
text: text[i],
text: rawText,
});
searchTokens.push(...tokenizeForSearch(rawText));
}
return ocrDataList;

return { ocrDataList, searchText: searchTokens.join(' ') };
}
}
42 changes: 41 additions & 1 deletion server/src/utils/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,46 @@ export function withTagId<O>(qb: SelectQueryBuilder<DB, 'asset', O>, tagId: stri
);
}

const isCJK = (c: number): boolean =>
(c >= 0x4e_00 && c <= 0x9f_ff) ||
(c >= 0xac_00 && c <= 0xd7_af) ||
(c >= 0x30_40 && c <= 0x30_9f) ||
(c >= 0x30_a0 && c <= 0x30_ff) ||
(c >= 0x34_00 && c <= 0x4d_bf);

export const tokenizeForSearch = (text: string): string[] => {
/* eslint-disable unicorn/prefer-code-point */
const tokens: string[] = [];
let i = 0;
while (i < text.length) {
const c = text.charCodeAt(i);
if (c <= 32) {
i++;
continue;
}

const start = i;
if (isCJK(c)) {
while (i < text.length && isCJK(text.charCodeAt(i))) {
i++;
}
if (i - start === 1) {
tokens.push(text[start]);
} else {
for (let k = start; k < i - 1; k++) {
tokens.push(text[k] + text[k + 1]);
}
}
} else {
while (i < text.length && text.charCodeAt(i) > 32 && !isCJK(text.charCodeAt(i))) {
i++;
}
tokens.push(text.slice(start, i));
}
}
return tokens;
};

const joinDeduplicationPlugin = new DeduplicateJoinsPlugin();
/** TODO: This should only be used for search-related queries, not as a general purpose query builder */

Expand Down Expand Up @@ -391,7 +431,7 @@ export function searchAssetBuilder(kysely: Kysely<DB>, options: AssetSearchBuild
.$if(!!options.ocr, (qb) =>
qb
.innerJoin('ocr_search', 'asset.id', 'ocr_search.assetId')
.where(() => sql`f_unaccent(ocr_search.text) %>> f_unaccent(${options.ocr!})`),
.where(() => sql`f_unaccent(ocr_search.text) %>> f_unaccent(${tokenizeForSearch(options.ocr!).join(' ')})`),
)
.$if(!!options.type, (qb) => qb.where('asset.type', '=', options.type!))
.$if(options.isFavorite !== undefined, (qb) => qb.where('asset.isFavorite', '=', options.isFavorite!))
Expand Down
Loading