Skip to content

Commit

Permalink
✨ Query models / datasets / spaces by tags (#498)
Browse files Browse the repository at this point in the history
Fix #497

## Query repos by tags

Eg for gguf model, `listModels({search: {tags: ["gguf"]}})`

For indonesian models, `listModels({search: {tags: ["id"]}})`

For models stored in the EU, `listModels({search: {tags:
["region:eu"]}})`

And so on.

## Request extra fields

The `additionalFields` param, previously only available for
`listSpaces`, is also available for `listModels` and `listDatasets`.

It's also now strongly typed.

So if you want to get the tags of the models you're listing:

```ts
for await (const model of listModels({additionalFields: ["tags"]})) {
  console.log(model.tags);
}
```

## Limit the number of models requested

You can specify the number of models to fetch, eg:

```ts
for await (const model of listModels({limit: 2})) {
  console.log(model.tags);
}
```

This will only fetch two models
  • Loading branch information
coyotte508 authored Feb 22, 2024
1 parent bea807a commit ee6740d
Show file tree
Hide file tree
Showing 8 changed files with 414 additions and 17 deletions.
50 changes: 47 additions & 3 deletions packages/hub/src/lib/list-datasets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,35 @@ import type { ApiDatasetInfo } from "../types/api/api-dataset";
import type { Credentials } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

const EXPAND_KEYS = ["private", "downloads", "gated", "likes", "lastModified"] satisfies (keyof ApiDatasetInfo)[];
const EXPAND_KEYS = [
"private",
"downloads",
"gated",
"likes",
"lastModified",
] as const satisfies readonly (keyof ApiDatasetInfo)[];

const EXPANDABLE_KEYS = [
"author",
"cardData",
"citation",
"createdAt",
"disabled",
"description",
"downloads",
"downloadsAllTime",
"gated",
"gitalyUid",
"lastModified",
"likes",
"paperswithcode_id",
"private",
// "siblings",
"sha",
"tags",
] as const satisfies readonly (keyof ApiDatasetInfo)[];

export interface DatasetEntry {
id: string;
Expand All @@ -17,24 +44,35 @@ export interface DatasetEntry {
updatedAt: Date;
}

export async function* listDatasets(params?: {
export async function* listDatasets<
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
>(params?: {
search?: {
owner?: string;
tags?: string[];
};
credentials?: Credentials;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): AsyncGenerator<DatasetEntry> {
checkCredentials(params?.credentials);
let totalToFetch = params?.limit ?? Infinity;
const search = new URLSearchParams([
...Object.entries({
limit: "500",
limit: String(Math.min(totalToFetch, 500)),
...(params?.search?.owner ? { author: params.search.owner } : undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : "");

Expand All @@ -54,6 +92,7 @@ export async function* listDatasets(params?: {

for (const item of items) {
yield {
...(params?.additionalFields && pick(item, params.additionalFields)),
id: item._id,
name: item.id,
private: item.private,
Expand All @@ -62,10 +101,15 @@ export async function* listDatasets(params?: {
gated: item.gated,
updatedAt: new Date(item.lastModified),
};
totalToFetch--;
if (totalToFetch <= 0) {
return;
}
}

const linkHeader = res.headers.get("Link");

url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
// Could update limit in url to fetch less items if not all items of next page are needed.
}
}
15 changes: 15 additions & 0 deletions packages/hub/src/lib/list-models.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,19 @@ describe("listModels", () => {
},
]);
});

it("should list indonesian models with gguf format", async () => {
let count = 0;
for await (const entry of listModels({
search: { tags: ["gguf", "id"] },
additionalFields: ["tags"],
limit: 2,
})) {
count++;
expect(entry.tags).to.include("gguf");
expect(entry.tags).to.include("id");
}

expect(count).to.equal(2);
});
});
53 changes: 48 additions & 5 deletions packages/hub/src/lib/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { ApiModelInfo } from "../types/api/api-model";
import type { Credentials, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

const EXPAND_KEYS = [
"pipeline_tag",
Expand All @@ -12,7 +13,31 @@ const EXPAND_KEYS = [
"downloads",
"likes",
"lastModified",
] satisfies (keyof ApiModelInfo)[];
] as const satisfies readonly (keyof ApiModelInfo)[];

const EXPANDABLE_KEYS = [
"author",
"cardData",
"config",
"createdAt",
"disabled",
"downloads",
"downloadsAllTime",
"gated",
"gitalyUid",
"lastModified",
"library_name",
"likes",
"model-index",
"pipeline_tag",
"private",
"safetensors",
"sha",
// "siblings",
"spaces",
"tags",
"transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];

export interface ModelEntry {
id: string;
Expand All @@ -25,26 +50,37 @@ export interface ModelEntry {
updatedAt: Date;
}

export async function* listModels(params?: {
export async function* listModels<
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
>(params?: {
search?: {
owner?: string;
task?: PipelineType;
tags?: string[];
};
credentials?: Credentials;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): AsyncGenerator<ModelEntry> {
}): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> {
checkCredentials(params?.credentials);
let totalToFetch = params?.limit ?? Infinity;
const search = new URLSearchParams([
...Object.entries({
limit: "500",
limit: String(Math.min(totalToFetch, 500)),
...(params?.search?.owner ? { author: params.search.owner } : undefined),
...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

Expand All @@ -64,6 +100,7 @@ export async function* listModels(params?: {

for (const item of items) {
yield {
...(params?.additionalFields && pick(item, params.additionalFields)),
id: item._id,
name: item.id,
private: item.private,
Expand All @@ -72,11 +109,17 @@ export async function* listModels(params?: {
gated: item.gated,
likes: item.likes,
updatedAt: new Date(item.lastModified),
};
} as ModelEntry & Pick<ApiModelInfo, T>;
totalToFetch--;

if (totalToFetch <= 0) {
return;
}
}

const linkHeader = res.headers.get("Link");

url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
}
}
32 changes: 27 additions & 5 deletions packages/hub/src/lib/list-spaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,42 @@ import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

const EXPAND_KEYS = ["sdk", "likes", "private", "lastModified"];
const EXPAND_KEYS = ["sdk", "likes", "private", "lastModified"] as const satisfies readonly (keyof ApiSpaceInfo)[];
const EXPANDABLE_KEYS = [
"author",
"cardData",
"datasets",
"disabled",
"gitalyUid",
"lastModified",
"createdAt",
"likes",
"private",
"runtime",
"sdk",
// "siblings",
"sha",
"subdomain",
"tags",
"models",
] as const satisfies readonly (keyof ApiSpaceInfo)[];

export type SpaceEntry = {
export interface SpaceEntry {
id: string;
name: string;
sdk?: SpaceSdk;
likes: number;
private: boolean;
updatedAt: Date;
// Use additionalFields to fetch the fields from ApiSpaceInfo
} & Partial<Omit<ApiSpaceInfo, "updatedAt">>;
}

export async function* listSpaces(params?: {
export async function* listSpaces<
const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
>(params?: {
search?: {
owner?: string;
tags?: string[];
};
credentials?: Credentials;
hubUrl?: string;
Expand All @@ -31,11 +52,12 @@ export async function* listSpaces(params?: {
/**
* Additional fields to fetch from huggingface.co.
*/
additionalFields?: Array<keyof ApiSpaceInfo>;
additionalFields?: T[];
}): AsyncGenerator<SpaceEntry> {
checkCredentials(params?.credentials);
const search = new URLSearchParams([
...Object.entries({ limit: "500", ...(params?.search?.owner ? { author: params.search.owner } : undefined) }),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...[...EXPAND_KEYS, ...(params?.additionalFields ?? [])].map((val) => ["expand", val] satisfies [string, string]),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`;
Expand Down
63 changes: 62 additions & 1 deletion packages/hub/src/types/api/api-dataset.d.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import type { License } from "../public";

export interface ApiDatasetInfo {
_id: string;
id: string;
arxivIds?: string[];
author?: string;
cardExists?: true;
cardError?: unknown;
cardData?: unknown;
cardData?: ApiDatasetMetadata;
contributors?: Array<{ user: string; _id: string }>;
disabled: boolean;
discussionsDisabled: boolean;
Expand All @@ -17,6 +19,9 @@ export interface ApiDatasetInfo {
likesRecent: number;
private: boolean;
updatedAt: string; // date
createdAt: string; // date
tags: string[];
paperswithcode_id?: string;
sha: string;
files?: string[];
citation?: string;
Expand All @@ -26,3 +31,59 @@ export interface ApiDatasetInfo {
previewable?: boolean;
doi?: { id: string; commit: string };
}

export interface ApiDatasetMetadata {
licenses?: undefined;
license?: License | License[];
license_name?: string;
license_link?: "LICENSE" | "LICENSE.md" | string;
license_details?: string;
languages?: undefined;
language?: string | string[];
language_bcp47?: string[];
language_details?: string;
tags?: string[];
task_categories?: string[];
task_ids?: string[];
config_names?: string[];
configs?: {
config_name: string;
data_files?:
| string
| string[]
| {
split: string;
path: string | string[];
}[];
data_dir?: string;
}[];
benchmark?: string;
paperswithcode_id?: string | null;
pretty_name?: string;
viewer?: boolean;
viewer_display_urls?: boolean;
thumbnail?: string | null;
description?: string | null;
annotations_creators?: string[];
language_creators?: string[];
multilinguality?: string[];
size_categories?: string[];
source_datasets?: string[];
extra_gated_prompt?: string;
extra_gated_fields?: {
/**
* "text" | "checkbox" | "date_picker" | "country" | "ip_location" | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } | { type: "select", options: Array<string | { label: string; value: string; }> } Property
*/
[x: string]:
| "text"
| "checkbox"
| "date_picker"
| "country"
| "ip_location"
| { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" }
| { type: "select"; options: Array<string | { label: string; value: string }> };
};
extra_gated_heading?: string;
extra_gated_description?: string;
extra_gated_button_content?: string;
}
Loading

0 comments on commit ee6740d

Please sign in to comment.