Skip to content

Commit

Permalink
feat(server): separate face search relation (immich-app#10371)
Browse files Browse the repository at this point in the history
* wip

* various fixes

* new migration

* fix test

* add face search entity, update sql

* update e2e

* set storage to external
  • Loading branch information
mertalev authored Jun 16, 2024
1 parent 0fe152b commit 6b1b505
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 47 deletions.
9 changes: 1 addition & 8 deletions e2e/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,14 +398,7 @@ export const utils = {
return;
}

const vector = Array.from({ length: 512 }, Math.random);
const embedding = `[${vector.join(',')}]`;

await client.query('INSERT INTO asset_faces ("assetId", "personId", "embedding") VALUES ($1, $2, $3)', [
assetId,
personId,
embedding,
]);
await client.query('INSERT INTO asset_faces ("assetId", "personId") VALUES ($1, $2)', [assetId, personId]);
},

setPersonThumbnail: async (personId: string) => {
Expand Down
8 changes: 4 additions & 4 deletions server/src/entities/asset-face.entity.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { AssetEntity } from 'src/entities/asset.entity';
import { FaceSearchEntity } from 'src/entities/face-search.entity';
import { PersonEntity } from 'src/entities/person.entity';
import { Column, Entity, Index, ManyToOne, PrimaryGeneratedColumn } from 'typeorm';
import { Column, Entity, Index, ManyToOne, OneToOne, PrimaryGeneratedColumn } from 'typeorm';

@Entity('asset_faces', { synchronize: false })
@Index('IDX_asset_faces_assetId_personId', ['assetId', 'personId'])
Expand All @@ -15,9 +16,8 @@ export class AssetFaceEntity {
@Column({ nullable: true, type: 'uuid' })
personId!: string | null;

@Index('face_index', { synchronize: false })
@Column({ type: 'float4', array: true, select: false, transformer: { from: (v) => JSON.parse(v), to: (v) => v } })
embedding!: number[];
@OneToOne(() => FaceSearchEntity, (faceSearchEntity) => faceSearchEntity.face, { cascade: ['insert'] })
faceSearch?: FaceSearchEntity;

@Column({ default: 0, type: 'int' })
imageWidth!: number;
Expand Down
21 changes: 21 additions & 0 deletions server/src/entities/face-search.entity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { AssetFaceEntity } from 'src/entities/asset-face.entity';
import { asVector } from 'src/utils/database';
import { Column, Entity, Index, JoinColumn, OneToOne, PrimaryColumn } from 'typeorm';

@Entity('face_search', { synchronize: false })
export class FaceSearchEntity {
@OneToOne(() => AssetFaceEntity, { onDelete: 'CASCADE', nullable: true })
@JoinColumn({ name: 'faceId', referencedColumnName: 'id' })
face?: AssetFaceEntity;

@PrimaryColumn()
faceId!: string;

@Index('face_index', { synchronize: false })
@Column({
type: 'float4',
array: true,
transformer: { from: (v) => JSON.parse(v), to: (v) => asVector(v) },
})
embedding!: number[];
}
2 changes: 2 additions & 0 deletions server/src/entities/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { AssetStackEntity } from 'src/entities/asset-stack.entity';
import { AssetEntity } from 'src/entities/asset.entity';
import { AuditEntity } from 'src/entities/audit.entity';
import { ExifEntity } from 'src/entities/exif.entity';
import { FaceSearchEntity } from 'src/entities/face-search.entity';
import { GeodataPlacesEntity } from 'src/entities/geodata-places.entity';
import { LibraryEntity } from 'src/entities/library.entity';
import { MemoryEntity } from 'src/entities/memory.entity';
Expand All @@ -34,6 +35,7 @@ export const entities = [
AssetJobStatusEntity,
AuditEntity,
ExifEntity,
FaceSearchEntity,
GeodataPlacesEntity,
MemoryEntity,
MoveEntity,
Expand Down
54 changes: 54 additions & 0 deletions server/src/migrations/1718486162779-AddFaceSearchRelation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { getVectorExtension } from 'src/database.config';
import { DatabaseExtension } from 'src/interfaces/database.interface';
import { MigrationInterface, QueryRunner } from 'typeorm';

export class AddFaceSearchRelation1718486162779 implements MigrationInterface {
public async up(queryRunner: QueryRunner): Promise<void> {
if (getVectorExtension() === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}

await queryRunner.query(`
CREATE TABLE face_search (
"faceId" uuid PRIMARY KEY REFERENCES asset_faces(id) ON DELETE CASCADE,
embedding vector(512) NOT NULL )`);

await queryRunner.query(`ALTER TABLE face_search ALTER COLUMN embedding SET STORAGE EXTERNAL`);
await queryRunner.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET STORAGE EXTERNAL`);

await queryRunner.query(`
INSERT INTO face_search("faceId", embedding)
SELECT id, embedding
FROM asset_faces faces`);

await queryRunner.query(`ALTER TABLE asset_faces DROP COLUMN "embedding"`);

await queryRunner.query(`
CREATE INDEX face_index ON face_search
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
}

public async down(queryRunner: QueryRunner): Promise<void> {
if (getVectorExtension() === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}

await queryRunner.query(`ALTER TABLE asset_faces ADD COLUMN "embedding" vector(512)`);
await queryRunner.query(`ALTER TABLE face_search ALTER COLUMN embedding SET STORAGE DEFAULT`);
await queryRunner.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET STORAGE DEFAULT`);
await queryRunner.query(`
UPDATE asset_faces
SET embedding = fs.embedding
FROM face_search fs
WHERE id = fs."faceId"`);
await queryRunner.query(`DROP TABLE face_search`);

await queryRunner.query(`
CREATE INDEX face_index ON asset_faces
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
}
}
5 changes: 3 additions & 2 deletions server/src/queries/search.repository.sql
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,16 @@ WITH
"faces"."boundingBoxY1" AS "boundingBoxY1",
"faces"."boundingBoxX2" AS "boundingBoxX2",
"faces"."boundingBoxY2" AS "boundingBoxY2",
"faces"."embedding" <= > $1 AS "distance"
"search"."embedding" <= > $1 AS "distance"
FROM
"asset_faces" "faces"
INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId"
AND ("asset"."deletedAt" IS NULL)
INNER JOIN "face_search" "search" ON "search"."faceId" = "faces"."id"
WHERE
"asset"."ownerId" IN ($2)
ORDER BY
"faces"."embedding" <= > $1 ASC
"search"."embedding" <= > $1 ASC
LIMIT
100
)
Expand Down
2 changes: 1 addition & 1 deletion server/src/repositories/database.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ export class DatabaseRepository implements IDatabaseRepository {
} catch (error) {
if (getVectorExtension() === DatabaseExtension.VECTORS) {
this.logger.warn(`Could not reindex index ${index}. Attempting to auto-fix.`);
const table = index === VectorIndex.CLIP ? 'smart_search' : 'asset_faces';
const table = index === VectorIndex.CLIP ? 'smart_search' : 'face_search';
const dimSize = await this.getDimSize(table);
await this.dataSource.manager.transaction(async (manager) => {
await this.setSearchPath(manager);
Expand Down
7 changes: 2 additions & 5 deletions server/src/repositories/person.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import {
PersonStatistics,
UpdateFacesData,
} from 'src/interfaces/person.interface';
import { asVector } from 'src/utils/database';
import { Instrumentation } from 'src/utils/instrumentation';
import { Paginated, PaginationOptions, paginate } from 'src/utils/pagination';
import { FindManyOptions, FindOptionsRelations, FindOptionsSelect, In, Repository } from 'typeorm';
Expand Down Expand Up @@ -249,10 +248,8 @@ export class PersonRepository implements IPersonRepository {
}

async createFaces(entities: AssetFaceEntity[]): Promise<string[]> {
const res = await this.assetFaceRepository.insert(
entities.map((entity) => ({ ...entity, embedding: () => asVector(entity.embedding, true) })),
);
return res.identifiers.map((row) => row.id);
const res = await this.assetFaceRepository.save(entities);
return res.map((row) => row.id);
}

async update(entity: Partial<PersonEntity>): Promise<PersonEntity> {
Expand Down
5 changes: 3 additions & 2 deletions server/src/repositories/search.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ export class SearchRepository implements ISearchRepository {
await this.assetRepository.manager.transaction(async (manager) => {
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('faces.embedding <=> :embedding', 'distance')
.select('search.embedding <=> :embedding', 'distance')
.innerJoin('faces.asset', 'asset')
.innerJoin('faces.faceSearch', 'search')
.where('asset.ownerId IN (:...userIds )')
.orderBy('faces.embedding <=> :embedding')
.orderBy('search.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) });

cte.limit(numResults);
Expand Down
5 changes: 4 additions & 1 deletion server/src/services/person.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -668,15 +668,18 @@ describe(PersonService.name, () => {
machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
assetMock.getByIds.mockResolvedValue([assetStub.image]);
const faceId = 'face-id';
cryptoMock.randomUUID.mockReturnValue(faceId);
const face = {
id: faceId,
assetId: 'asset-id',
embedding: [1, 2, 3, 4],
boundingBoxX1: 100,
boundingBoxY1: 100,
boundingBoxX2: 200,
boundingBoxY2: 200,
imageHeight: 500,
imageWidth: 400,
faceSearch: { faceId, embedding: [1, 2, 3, 4] },
};

await sut.handleDetectFaces({ id: assetStub.image.id });
Expand Down
41 changes: 26 additions & 15 deletions server/src/services/person.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
mapFaces,
mapPerson,
} from 'src/dtos/person.dto';
import { AssetFaceEntity } from 'src/entities/asset-face.entity';
import { AssetEntity, AssetType } from 'src/entities/asset.entity';
import { PersonPathType } from 'src/entities/move.entity';
import { PersonEntity } from 'src/entities/person.entity';
Expand Down Expand Up @@ -70,7 +71,7 @@ export class PersonService {
@Inject(IStorageRepository) private storageRepository: IStorageRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(ISearchRepository) private smartInfoRepository: ISearchRepository,
@Inject(ICryptoRepository) cryptoRepository: ICryptoRepository,
@Inject(ICryptoRepository) private cryptoRepository: ICryptoRepository,
@Inject(ILoggerRepository) private logger: ILoggerRepository,
) {
this.access = AccessCore.create(accessRepository);
Expand Down Expand Up @@ -347,16 +348,21 @@ export class PersonService {

if (faces.length > 0) {
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
const mappedFaces = faces.map((face) => ({
assetId: asset.id,
embedding: face.embedding,
imageHeight,
imageWidth,
boundingBoxX1: face.boundingBox.x1,
boundingBoxY1: face.boundingBox.y1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY2: face.boundingBox.y2,
}));
const mappedFaces: Partial<AssetFaceEntity>[] = [];
for (const face of faces) {
const faceId = this.cryptoRepository.randomUUID();
mappedFaces.push({
id: faceId,
assetId: asset.id,
imageHeight,
imageWidth,
boundingBoxX1: face.boundingBox.x1,
boundingBoxY1: face.boundingBox.y1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY2: face.boundingBox.y2,
faceSearch: { faceId, embedding: face.embedding },
});
}

const faceIds = await this.repository.createFaces(mappedFaces);
await this.jobRepository.queueAll(faceIds.map((id) => ({ name: JobName.FACIAL_RECOGNITION, data: { id } })));
Expand Down Expand Up @@ -409,22 +415,27 @@ export class PersonService {

const face = await this.repository.getFaceByIdWithAssets(
id,
{ person: true, asset: true },
{ id: true, personId: true, embedding: true },
{ person: true, asset: true, faceSearch: true },
{ id: true, personId: true, faceSearch: { embedding: true } },
);
if (!face || !face.asset) {
this.logger.warn(`Face ${id} not found`);
return JobStatus.FAILED;
}

if (!face.faceSearch?.embedding) {
this.logger.warn(`Face ${id} does not have an embedding`);
return JobStatus.FAILED;
}

if (face.personId) {
this.logger.debug(`Face ${id} already has a person assigned`);
return JobStatus.SKIPPED;
}

const matches = await this.smartInfoRepository.searchFaces({
userIds: [face.asset.ownerId],
embedding: face.embedding,
embedding: face.faceSearch.embedding,
maxDistance: machineLearning.facialRecognition.maxDistance,
numResults: machineLearning.facialRecognition.minFaces,
});
Expand All @@ -448,7 +459,7 @@ export class PersonService {
if (!personId) {
const matchWithPerson = await this.smartInfoRepository.searchFaces({
userIds: [face.asset.ownerId],
embedding: face.embedding,
embedding: face.faceSearch.embedding,
maxDistance: machineLearning.facialRecognition.maxDistance,
numResults: 1,
hasPerson: true,
Expand Down
Loading

0 comments on commit 6b1b505

Please sign in to comment.