From ee6740d2364bd5b9f94e2515b896ec065d3eec91 Mon Sep 17 00:00:00 2001 From: Eliott C Date: Thu, 22 Feb 2024 23:46:50 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Query=20models=20/=20datasets=20/?= =?UTF-8?q?=20spaces=20by=20tags=20(#498)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix https://github.com/huggingface/huggingface.js/issues/497 ## Query repos by tags E.g. for GGUF models, `listModels({search: {tags: ["gguf"]}})` For Indonesian models, `listModels({search: {tags: ["id"]}})` For models stored in the EU, `listModels({search: {tags: ["region:eu"]}})` And so on. ## Request extra fields The `additionalFields` param, previously only available for `listSpaces`, is also available for `listModels` and `listDatasets`. It's also now strongly typed. So if you want to get the tags of the models you're listing: ```ts for await (const model of listModels({additionalFields: ["tags"]})) { console.log(model.tags); } ``` ## Limit the number of models requested You can specify the number of models to fetch, e.g.: ```ts for await (const model of listModels({limit: 2})) { console.log(model.tags); } ``` This will only fetch two models. --- packages/hub/src/lib/list-datasets.ts | 50 ++++++++++- packages/hub/src/lib/list-models.spec.ts | 15 ++++ packages/hub/src/lib/list-models.ts | 53 ++++++++++-- packages/hub/src/lib/list-spaces.ts | 32 ++++++-- packages/hub/src/types/api/api-dataset.d.ts | 63 +++++++++++++- packages/hub/src/types/api/api-model.d.ts | 91 ++++++++++++++++++++- packages/hub/src/types/api/api-space.d.ts | 57 ++++++++++++- packages/hub/src/types/public.d.ts | 70 ++++++++++++++++ 8 files changed, 414 insertions(+), 17 deletions(-) diff --git a/packages/hub/src/lib/list-datasets.ts b/packages/hub/src/lib/list-datasets.ts index 4109f969d..8aca852b7 100644 --- a/packages/hub/src/lib/list-datasets.ts +++ b/packages/hub/src/lib/list-datasets.ts @@ -4,8 +4,35 @@ import type { ApiDatasetInfo } from "../types/api/api-dataset"; 
import type { Credentials } from "../types/public"; import { checkCredentials } from "../utils/checkCredentials"; import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { pick } from "../utils/pick"; -const EXPAND_KEYS = ["private", "downloads", "gated", "likes", "lastModified"] satisfies (keyof ApiDatasetInfo)[]; +const EXPAND_KEYS = [ + "private", + "downloads", + "gated", + "likes", + "lastModified", +] as const satisfies readonly (keyof ApiDatasetInfo)[]; + +const EXPANDABLE_KEYS = [ + "author", + "cardData", + "citation", + "createdAt", + "disabled", + "description", + "downloads", + "downloadsAllTime", + "gated", + "gitalyUid", + "lastModified", + "likes", + "paperswithcode_id", + "private", + // "siblings", + "sha", + "tags", +] as const satisfies readonly (keyof ApiDatasetInfo)[]; export interface DatasetEntry { id: string; @@ -17,24 +44,35 @@ export interface DatasetEntry { updatedAt: Date; } -export async function* listDatasets(params?: { +export async function* listDatasets< + const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never, +>(params?: { search?: { owner?: string; + tags?: string[]; }; credentials?: Credentials; hubUrl?: string; + additionalFields?: T[]; + /** + * Set to limit the number of models returned. + */ + limit?: number; /** * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ fetch?: typeof fetch; }): AsyncGenerator { checkCredentials(params?.credentials); + let totalToFetch = params?.limit ?? Infinity; const search = new URLSearchParams([ ...Object.entries({ - limit: "500", + limit: String(Math.min(totalToFetch, 500)), ...(params?.search?.owner ? { author: params.search.owner } : undefined), }), + ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []), ...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), + ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? 
[]), ]).toString(); let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : ""); @@ -54,6 +92,7 @@ export async function* listDatasets(params?: { for (const item of items) { yield { + ...(params?.additionalFields && pick(item, params.additionalFields)), id: item._id, name: item.id, private: item.private, @@ -62,10 +101,15 @@ export async function* listDatasets(params?: { gated: item.gated, updatedAt: new Date(item.lastModified), }; + totalToFetch--; + if (totalToFetch <= 0) { + return; + } } const linkHeader = res.headers.get("Link"); url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; + // Could update limit in url to fetch less items if not all items of next page are needed. } } diff --git a/packages/hub/src/lib/list-models.spec.ts b/packages/hub/src/lib/list-models.spec.ts index 75990b03a..ab8c7f638 100644 --- a/packages/hub/src/lib/list-models.spec.ts +++ b/packages/hub/src/lib/list-models.spec.ts @@ -52,4 +52,19 @@ describe("listModels", () => { }, ]); }); + + it("should list indonesian models with gguf format", async () => { + let count = 0; + for await (const entry of listModels({ + search: { tags: ["gguf", "id"] }, + additionalFields: ["tags"], + limit: 2, + })) { + count++; + expect(entry.tags).to.include("gguf"); + expect(entry.tags).to.include("id"); + } + + expect(count).to.equal(2); + }); }); diff --git a/packages/hub/src/lib/list-models.ts b/packages/hub/src/lib/list-models.ts index 7fde45d68..2f589b2fa 100644 --- a/packages/hub/src/lib/list-models.ts +++ b/packages/hub/src/lib/list-models.ts @@ -4,6 +4,7 @@ import type { ApiModelInfo } from "../types/api/api-model"; import type { Credentials, PipelineType } from "../types/public"; import { checkCredentials } from "../utils/checkCredentials"; import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { pick } from "../utils/pick"; const EXPAND_KEYS = [ "pipeline_tag", @@ -12,7 +13,31 @@ const EXPAND_KEYS = [ "downloads", "likes", 
"lastModified", -] satisfies (keyof ApiModelInfo)[]; +] as const satisfies readonly (keyof ApiModelInfo)[]; + +const EXPANDABLE_KEYS = [ + "author", + "cardData", + "config", + "createdAt", + "disabled", + "downloads", + "downloadsAllTime", + "gated", + "gitalyUid", + "lastModified", + "library_name", + "likes", + "model-index", + "pipeline_tag", + "private", + "safetensors", + "sha", + // "siblings", + "spaces", + "tags", + "transformersInfo", +] as const satisfies readonly (keyof ApiModelInfo)[]; export interface ModelEntry { id: string; @@ -25,26 +50,37 @@ export interface ModelEntry { updatedAt: Date; } -export async function* listModels(params?: { +export async function* listModels< + const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never, +>(params?: { search?: { owner?: string; task?: PipelineType; + tags?: string[]; }; credentials?: Credentials; hubUrl?: string; + additionalFields?: T[]; + /** + * Set to limit the number of models returned. + */ + limit?: number; /** * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ fetch?: typeof fetch; -}): AsyncGenerator { +}): AsyncGenerator> { checkCredentials(params?.credentials); + let totalToFetch = params?.limit ?? Infinity; const search = new URLSearchParams([ ...Object.entries({ - limit: "500", + limit: String(Math.min(totalToFetch, 500)), ...(params?.search?.owner ? { author: params.search.owner } : undefined), ...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined), }), + ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []), ...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), + ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? 
[]), ]).toString(); let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`; @@ -64,6 +100,7 @@ export async function* listModels(params?: { for (const item of items) { yield { + ...(params?.additionalFields && pick(item, params.additionalFields)), id: item._id, name: item.id, private: item.private, @@ -72,11 +109,17 @@ export async function* listModels(params?: { gated: item.gated, likes: item.likes, updatedAt: new Date(item.lastModified), - }; + } as ModelEntry & Pick; + totalToFetch--; + + if (totalToFetch <= 0) { + return; + } } const linkHeader = res.headers.get("Link"); url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; + // Could update url to reduce the limit if we don't need the whole 500 of the next batch. } } diff --git a/packages/hub/src/lib/list-spaces.ts b/packages/hub/src/lib/list-spaces.ts index 2e23f2dfb..818e2ae9a 100644 --- a/packages/hub/src/lib/list-spaces.ts +++ b/packages/hub/src/lib/list-spaces.ts @@ -6,9 +6,27 @@ import { checkCredentials } from "../utils/checkCredentials"; import { parseLinkHeader } from "../utils/parseLinkHeader"; import { pick } from "../utils/pick"; -const EXPAND_KEYS = ["sdk", "likes", "private", "lastModified"]; +const EXPAND_KEYS = ["sdk", "likes", "private", "lastModified"] as const satisfies readonly (keyof ApiSpaceInfo)[]; +const EXPANDABLE_KEYS = [ + "author", + "cardData", + "datasets", + "disabled", + "gitalyUid", + "lastModified", + "createdAt", + "likes", + "private", + "runtime", + "sdk", + // "siblings", + "sha", + "subdomain", + "tags", + "models", +] as const satisfies readonly (keyof ApiSpaceInfo)[]; -export type SpaceEntry = { +export interface SpaceEntry { id: string; name: string; sdk?: SpaceSdk; @@ -16,11 +34,14 @@ export type SpaceEntry = { private: boolean; updatedAt: Date; // Use additionalFields to fetch the fields from ApiSpaceInfo -} & Partial>; +} -export async function* listSpaces(params?: { +export async function* listSpaces< + const T extends 
Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never, +>(params?: { search?: { owner?: string; + tags?: string[]; }; credentials?: Credentials; hubUrl?: string; @@ -31,11 +52,12 @@ export async function* listSpaces(params?: { /** * Additional fields to fetch from huggingface.co. */ - additionalFields?: Array; + additionalFields?: T[]; }): AsyncGenerator { checkCredentials(params?.credentials); const search = new URLSearchParams([ ...Object.entries({ limit: "500", ...(params?.search?.owner ? { author: params.search.owner } : undefined) }), + ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []), ...[...EXPAND_KEYS, ...(params?.additionalFields ?? [])].map((val) => ["expand", val] satisfies [string, string]), ]).toString(); let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`; diff --git a/packages/hub/src/types/api/api-dataset.d.ts b/packages/hub/src/types/api/api-dataset.d.ts index 3eb16b3ad..43b097853 100644 --- a/packages/hub/src/types/api/api-dataset.d.ts +++ b/packages/hub/src/types/api/api-dataset.d.ts @@ -1,3 +1,5 @@ +import type { License } from "../public"; + export interface ApiDatasetInfo { _id: string; id: string; @@ -5,7 +7,7 @@ export interface ApiDatasetInfo { author?: string; cardExists?: true; cardError?: unknown; - cardData?: unknown; + cardData?: ApiDatasetMetadata; contributors?: Array<{ user: string; _id: string }>; disabled: boolean; discussionsDisabled: boolean; @@ -17,6 +19,9 @@ export interface ApiDatasetInfo { likesRecent: number; private: boolean; updatedAt: string; // date + createdAt: string; // date + tags: string[]; + paperswithcode_id?: string; sha: string; files?: string[]; citation?: string; @@ -26,3 +31,59 @@ export interface ApiDatasetInfo { previewable?: boolean; doi?: { id: string; commit: string }; } + +export interface ApiDatasetMetadata { + licenses?: undefined; + license?: License | License[]; + license_name?: string; + license_link?: "LICENSE" | "LICENSE.md" | 
string; + license_details?: string; + languages?: undefined; + language?: string | string[]; + language_bcp47?: string[]; + language_details?: string; + tags?: string[]; + task_categories?: string[]; + task_ids?: string[]; + config_names?: string[]; + configs?: { + config_name: string; + data_files?: + | string + | string[] + | { + split: string; + path: string | string[]; + }[]; + data_dir?: string; + }[]; + benchmark?: string; + paperswithcode_id?: string | null; + pretty_name?: string; + viewer?: boolean; + viewer_display_urls?: boolean; + thumbnail?: string | null; + description?: string | null; + annotations_creators?: string[]; + language_creators?: string[]; + multilinguality?: string[]; + size_categories?: string[]; + source_datasets?: string[]; + extra_gated_prompt?: string; + extra_gated_fields?: { + /** + * "text" | "checkbox" | "date_picker" | "country" | "ip_location" | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } | { type: "select", options: Array } Property + */ + [x: string]: + | "text" + | "checkbox" + | "date_picker" + | "country" + | "ip_location" + | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } + | { type: "select"; options: Array }; + }; + extra_gated_heading?: string; + extra_gated_description?: string; + extra_gated_button_content?: string; +} diff --git a/packages/hub/src/types/api/api-model.d.ts b/packages/hub/src/types/api/api-model.d.ts index 410a07594..9cb8c6957 100644 --- a/packages/hub/src/types/api/api-model.d.ts +++ b/packages/hub/src/types/api/api-model.d.ts @@ -1,11 +1,12 @@ -import type { PipelineType } from "../public"; +import type { ModelLibraryKey, TransformersInfo } from "@huggingface/tasks"; +import type { License, PipelineType } from "../public"; export interface ApiModelInfo { _id: string; id: string; arxivIds: string[]; author?: string; - cardData: unknown; + cardData?: ApiModelMetadata; cardError: unknown; cardExists?: true; config: unknown; @@ -19,6 +20,7 @@ 
export interface ApiModelInfo { gitalyUid: string; lastAuthor: { email: string; user?: string }; lastModified: string; // convert to date + library_name?: ModelLibraryKey; likes: number; likesRecent: number; private: boolean; @@ -26,6 +28,91 @@ export interface ApiModelInfo { sha: string; spaces: string[]; updatedAt: string; // convert to date + createdAt: string; // convert to date pipeline_tag: PipelineType; + tags: string[]; "model-index": unknown; + safetensors?: { + parameters: Record; + total: number; + }; + transformersInfo?: TransformersInfo; +} + +export interface ApiModelMetadata { + datasets?: string | string[]; + license?: License | License[]; + license_name?: string; + license_link?: "LICENSE" | "LICENSE.md" | string; + license_details?: string; + inference?: + | boolean + | { + parameters?: { + aggregation_strategy?: string; + top_k?: number; + top_p?: number; + temperature?: number; + max_new_tokens?: number; + do_sample?: boolean; + negative_prompt?: string; + guidance_scale?: number; + num_inference_steps?: number; + }; + }; + language?: string | string[]; + language_bcp47?: string[]; + language_details?: string; + tags?: string[]; + pipeline_tag?: string; + co2_eq_emissions?: + | number + | { + /** + * Emissions in grams of CO2 + */ + emissions: number; + /** + * source of the information, either directly from AutoTrain, code carbon or from a scientific article documenting the model + */ + source?: string; + /** + * pre-training or fine-tuning + */ + training_type?: string; + /** + * as granular as possible, for instance Quebec, Canada or Brooklyn, NY, USA + */ + geographical_location?: string; + /** + * how much compute and what kind, e.g. 
8 v100 GPUs + */ + hardware_used?: string; + }; + library_name?: string; + thumbnail?: string | null; + description?: string | null; + mask_token?: string; + widget?: WidgetExampleFromModelcard[]; + "model-index"?: ModelIndex[]; + finetuned_from?: string; + base_model?: string | string[]; + instance_prompt?: string | null; + extra_gated_prompt?: string; + extra_gated_fields?: { + /** + * "text" | "checkbox" | "date_picker" | "country" | "ip_location" | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } | { type: "select", options: Array } Property + */ + [x: string]: + | "text" + | "checkbox" + | "date_picker" + | "country" + | "ip_location" + | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } + | { type: "select"; options: Array }; + }; + extra_gated_heading?: string; + extra_gated_description?: string; + extra_gated_button_content?: string; } diff --git a/packages/hub/src/types/api/api-space.d.ts b/packages/hub/src/types/api/api-space.d.ts index 7a0a3b97c..151677f4b 100644 --- a/packages/hub/src/types/api/api-space.d.ts +++ b/packages/hub/src/types/api/api-space.d.ts @@ -1,4 +1,4 @@ -import type { SpaceRuntime, SpaceSdk } from "../public"; +import type { License, SpaceRuntime, SpaceSdk } from "../public"; type Color = "red" | "yellow" | "green" | "blue" | "indigo" | "purple" | "pink" | "gray"; @@ -22,6 +22,8 @@ export interface ApiSpaceInfo { likesRecent: number; private: boolean; updatedAt: string; // date + createdAt: string; // date + tags: string[]; sha: string; subdomain: string; title: string; @@ -36,3 +38,56 @@ export interface ApiSpaceInfo { datasets?: string[]; originSpace?: { _id: string; authorId: string }; } + +export interface ApiSpaceMetadata { + license?: License | License[]; + tags?: string[]; + title?: string; + colorFrom?: "red" | "yellow" | "green" | "blue" | "indigo" | "purple" | "pink" | "gray"; + colorTo?: "red" | "yellow" | "green" | "blue" | "indigo" | "purple" | "pink" | "gray"; + emoji?: 
string; + sdk?: "streamlit" | "gradio" | "docker" | "static"; + sdk_version?: string | string; + python_version?: string | string; + fullWidth?: boolean; + header?: "mini" | "default"; + app_file?: string; + app_port?: number; + base_path?: string; + models?: string[]; + datasets?: string[]; + pinned?: boolean; + metaTitle?: string; + description?: string; + thumbnail?: string; + /** + * If enabled, will associate an oauth app to the Space, adding variables and secrets to the Space's environment + */ + hf_oauth?: boolean; + /** + * The expiration of access tokens for your oauth app in minutes. max 30 days (43,200 minutes). Defaults to 8 hours (480 minutes) + */ + hf_oauth_expiration_minutes?: number; + /** + * OAuth scopes to request. By default you have access to the user's profile, you can request access to their repos or inference-api. + */ + hf_oauth_scopes?: ("email" | "read-repos" | "write-repos" | "manage-repos" | "inference-api")[]; + suggested_hardware?: + | "cpu-basic" + | "zero-a10g" + | "cpu-upgrade" + | "cpu-xl" + | "t4-small" + | "t4-medium" + | "a10g-small" + | "a10g-large" + | "a10g-largex2" + | "a10g-largex4" + | "a100-large"; + suggested_storage?: "small" | "medium" | "large"; + custom_headers?: { + "cross-origin-embedder-policy"?: "unsafe-none" | "require-corp" | "credentialless"; + "cross-origin-opener-policy"?: "same-origin" | "same-origin-allow-popups" | "unsafe-none"; + "cross-origin-resource-policy"?: "same-site" | "same-origin" | "cross-origin"; + }; +} diff --git a/packages/hub/src/types/public.d.ts b/packages/hub/src/types/public.d.ts index bab22fee6..3b2ec6126 100644 --- a/packages/hub/src/types/public.d.ts +++ b/packages/hub/src/types/public.d.ts @@ -79,3 +79,73 @@ export interface SpaceResourceConfig { throttled?: boolean; is_custom?: boolean; } + +export type License = + | "apache-2.0" + | "mit" + | "openrail" + | "bigscience-openrail-m" + | "creativeml-openrail-m" + | "bigscience-bloom-rail-1.0" + | "bigcode-openrail-m" + | "afl-3.0" 
+ | "artistic-2.0" + | "bsl-1.0" + | "bsd" + | "bsd-2-clause" + | "bsd-3-clause" + | "bsd-3-clause-clear" + | "c-uda" + | "cc" + | "cc0-1.0" + | "cc-by-2.0" + | "cc-by-2.5" + | "cc-by-3.0" + | "cc-by-4.0" + | "cc-by-sa-3.0" + | "cc-by-sa-4.0" + | "cc-by-nc-2.0" + | "cc-by-nc-3.0" + | "cc-by-nc-4.0" + | "cc-by-nd-4.0" + | "cc-by-nc-nd-3.0" + | "cc-by-nc-nd-4.0" + | "cc-by-nc-sa-2.0" + | "cc-by-nc-sa-3.0" + | "cc-by-nc-sa-4.0" + | "cdla-sharing-1.0" + | "cdla-permissive-1.0" + | "cdla-permissive-2.0" + | "wtfpl" + | "ecl-2.0" + | "epl-1.0" + | "epl-2.0" + | "etalab-2.0" + | "eupl-1.1" + | "agpl-3.0" + | "gfdl" + | "gpl" + | "gpl-2.0" + | "gpl-3.0" + | "lgpl" + | "lgpl-2.1" + | "lgpl-3.0" + | "isc" + | "lppl-1.3c" + | "ms-pl" + | "mpl-2.0" + | "odc-by" + | "odbl" + | "openrail++" + | "osl-3.0" + | "postgresql" + | "ofl-1.1" + | "ncsa" + | "unlicense" + | "zlib" + | "pddl" + | "lgpl-lr" + | "deepfloyd-if-license" + | "llama2" + | "unknown" + | "other";