From bea807ae66c936875c2f906622ee07114fe2e992 Mon Sep 17 00:00:00 2001
From: Simon Brandeis <33657802+SBrandeis@users.noreply.github.com>
Date: Thu, 22 Feb 2024 17:14:00 +0100
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20[Widgets]=20Enable=20streaming=20in?=
 =?UTF-8?q?=20the=20conversational=20widget=20(#486)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Linked to #360 and #410.

This should unlock the conversational widget for Mistral models.

# TL;DR

- Leverage the inference types from `@huggingface/tasks` to type the inputs and outputs of the inference client
- Use the inference client to call the serverless Inference API
- Use the streaming API when the model supports it (a usage sketch follows the patch below)

---------

Co-authored-by: Mishig
Co-authored-by: Victor Mustar
---
 .../inference/src/tasks/nlp/textGeneration.ts |  66 +----------
 .../src/tasks/nlp/textGenerationStream.ts     |   7 +-
 packages/tasks/src/tasks/index.ts             |  58 +++++++++
 packages/widgets/package.json                 |   5 +-
 packages/widgets/pnpm-lock.yaml               | 103 ++++++++++++----
 .../shared/WidgetHeader/WidgetHeader.svelte   |  18 +--
 .../WidgetOutputConvoBubble.svelte            |   2 +-
 .../WidgetOutputConvo.svelte                  |   2 +-
 .../lib/components/InferenceWidget/stores.ts  |  23 +++-
 .../ConversationalWidget.svelte               | 112 +++++++++---------
 packages/widgets/src/routes/+page.svelte      |  15 ++-
 packages/widgets/tailwind.config.cjs          |   3 +
 12 files changed, 255 insertions(+), 159 deletions(-)

diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts
index 6b550a884..8452b1a17 100644
--- a/packages/inference/src/tasks/nlp/textGeneration.ts
+++ b/packages/inference/src/tasks/nlp/textGeneration.ts
@@ -1,71 +1,15 @@
+import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks/src/tasks/text-generation/inference";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
-export type TextGenerationArgs = BaseArgs & {
-  /**
-   * A string to be generated from
-   */
-  inputs: string;
-  parameters?: {
-    /**
-     * (Optional: True). Bool. Whether or not to use sampling, use greedy decoding otherwise.
-     */
-    do_sample?: boolean;
-    /**
-     * (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input length it is a estimate of the size of generated text you want. Each new tokens slows down the request, so look for balance between response times and length of text generated.
-     */
-    max_new_tokens?: number;
-    /**
-     * (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens for best results.
-     */
-    max_time?: number;
-    /**
-     * (Default: 1). Integer. The number of proposition you want to be returned.
-     */
-    num_return_sequences?: number;
-    /**
-     * (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized to not be picked in successive generation passes.
-     */
-    repetition_penalty?: number;
-    /**
-     * (Default: True). Bool. If set to False, the return results will not contain the original query making it easier for prompting.
-     */
-    return_full_text?: boolean;
-    /**
-     * (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
-     */
-    temperature?: number;
-    /**
-     * (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
-     */
-    top_k?: number;
-    /**
-     * (Default: None). Float to define the tokens that are within the sample operation of text generation. Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p.
-     */
-    top_p?: number;
-    /**
-     * (Default: None). Integer. The maximum number of tokens from the input.
-     */
-    truncate?: number;
-    /**
-     * (Default: []) List of strings. The model will stop generating text when one of the strings in the list is generated.
-     * **/
-    stop_sequences?: string[];
-  };
-};
-
-export interface TextGenerationOutput {
-  /**
-   * The continuated string
-   */
-  generated_text: string;
-}
-
 /**
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
  */
-export async function textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationOutput> {
+export async function textGeneration(
+  args: BaseArgs & TextGenerationInput,
+  options?: Options
+): Promise<TextGenerationOutput> {
   const res = await request(args, {
     ...options,
     taskHint: "text-generation",
diff --git a/packages/inference/src/tasks/nlp/textGenerationStream.ts b/packages/inference/src/tasks/nlp/textGenerationStream.ts
index 24c1ce8f3..99e1c2b7f 100644
--- a/packages/inference/src/tasks/nlp/textGenerationStream.ts
+++ b/packages/inference/src/tasks/nlp/textGenerationStream.ts
@@ -1,6 +1,7 @@
-import type { Options } from "../../types";
+import type { BaseArgs, Options } from "../../types";
 import { streamingRequest } from "../custom/streamingRequest";
-import type { TextGenerationArgs } from "./textGeneration";
+
+import type { TextGenerationInput } from "@huggingface/tasks/src/tasks/text-generation/inference";
 
 export interface TextGenerationStreamToken {
   /** Token ID from the model tokenizer */
@@ -85,7 +86,7 @@ export interface TextGenerationStreamOutput {
  * Use to continue text from a prompt.
Same as `textGeneration` but returns generator that can be read one token at a time */ export async function* textGenerationStream( - args: TextGenerationArgs, + args: BaseArgs & TextGenerationInput, options?: Options ): AsyncGenerator { yield* streamingRequest(args, { diff --git a/packages/tasks/src/tasks/index.ts b/packages/tasks/src/tasks/index.ts index 6630851b6..d60ac55db 100644 --- a/packages/tasks/src/tasks/index.ts +++ b/packages/tasks/src/tasks/index.ts @@ -36,6 +36,64 @@ import zeroShotClassification from "./zero-shot-classification/data"; import zeroShotImageClassification from "./zero-shot-image-classification/data"; import zeroShotObjectDetection from "./zero-shot-object-detection/data"; +export type * from "./audio-classification/inference"; +export type * from "./automatic-speech-recognition/inference"; +export type * from "./document-question-answering/inference"; +export type * from "./feature-extraction/inference"; +export type * from "./fill-mask/inference"; +export type { + ImageClassificationInput, + ImageClassificationOutput, + ImageClassificationOutputElement, + ImageClassificationParameters, +} from "./image-classification/inference"; +export type * from "./image-to-image/inference"; +export type { ImageToTextInput, ImageToTextOutput, ImageToTextParameters } from "./image-to-text/inference"; +export type * from "./image-segmentation/inference"; +export type * from "./object-detection/inference"; +export type * from "./depth-estimation/inference"; +export type * from "./question-answering/inference"; +export type * from "./sentence-similarity/inference"; +export type * from "./summarization/inference"; +export type * from "./table-question-answering/inference"; +export type { TextToImageInput, TextToImageOutput, TextToImageParameters } from "./text-to-image/inference"; +export type { TextToAudioParameters, TextToSpeechInput, TextToSpeechOutput } from "./text-to-speech/inference"; +export type * from "./token-classification/inference"; +export type { + Text2TextGenerationParameters, + Text2TextGenerationTruncationStrategy, + TranslationInput, + TranslationOutput, +} from "./translation/inference"; +export type { + ClassificationOutputTransform, + TextClassificationInput, + TextClassificationOutput, + TextClassificationOutputElement, + TextClassificationParameters, +} from "./text-classification/inference"; +export type { + FinishReason, + PrefillToken, + TextGenerationInput, + TextGenerationOutput, + TextGenerationOutputDetails, + TextGenerationParameters, + TextGenerationSequenceDetails, + Token, +} from "./text-generation/inference"; +export type * from "./video-classification/inference"; +export type * from "./visual-question-answering/inference"; +export type * from "./zero-shot-classification/inference"; +export type * from "./zero-shot-image-classification/inference"; +export type { + BoundingBox, + ZeroShotObjectDetectionInput, + ZeroShotObjectDetectionInputData, + ZeroShotObjectDetectionOutput, + ZeroShotObjectDetectionOutputElement, +} from "./zero-shot-object-detection/inference"; + import type { ModelLibraryKey } from "../model-libraries"; /** diff --git a/packages/widgets/package.json b/packages/widgets/package.json index 4e75f6a67..751be18cf 100644 --- a/packages/widgets/package.json +++ b/packages/widgets/package.json @@ -46,7 +46,8 @@ ], "dependencies": { "@huggingface/tasks": "workspace:^", - "@huggingface/jinja": "workspace:^" + "@huggingface/jinja": "workspace:^", + "@huggingface/inference": "workspace:^" }, "peerDependencies": { "svelte": "^3.59.2" 
@@ -69,7 +70,7 @@ "svelte": "^3.59.2", "svelte-check": "^3.6.0", "svelte-preprocess": "^5.1.1", - "tailwindcss": "^3.3.5", + "tailwindcss": "^3.4.1", "tslib": "^2.4.1", "vite": "^4.5.0", "vite-plugin-dts": "^3.6.4" diff --git a/packages/widgets/pnpm-lock.yaml b/packages/widgets/pnpm-lock.yaml index 3eb4ebcbf..2d2fad743 100644 --- a/packages/widgets/pnpm-lock.yaml +++ b/packages/widgets/pnpm-lock.yaml @@ -5,6 +5,9 @@ settings: excludeLinksFromLockfile: false dependencies: + '@huggingface/inference': + specifier: workspace:^ + version: link:../inference '@huggingface/jinja': specifier: workspace:^ version: link:../jinja @@ -33,13 +36,13 @@ devDependencies: version: 1.27.4(svelte@3.59.2)(vite@4.5.0) '@sveltejs/package': specifier: ^2.0.0 - version: 2.0.0(svelte@3.59.2)(typescript@5.0.4) + version: 2.0.0(svelte@3.59.2)(typescript@5.3.3) '@sveltejs/vite-plugin-svelte': specifier: 2.5.3 version: 2.5.3(svelte@3.59.2)(vite@4.5.0) '@tailwindcss/forms': specifier: ^0.5.7 - version: 0.5.7(tailwindcss@3.3.5) + version: 0.5.7(tailwindcss@3.4.1) '@types/node': specifier: '20' version: 20.10.1 @@ -63,10 +66,10 @@ devDependencies: version: 3.6.0(postcss@8.4.31)(svelte@3.59.2) svelte-preprocess: specifier: ^5.1.1 - version: 5.1.1(postcss@8.4.31)(svelte@3.59.2)(typescript@5.0.4) + version: 5.1.1(postcss@8.4.31)(svelte@3.59.2)(typescript@5.3.3) tailwindcss: - specifier: ^3.3.5 - version: 3.3.5 + specifier: ^3.4.1 + version: 3.4.1 tslib: specifier: ^2.4.1 version: 2.4.1 @@ -75,7 +78,7 @@ devDependencies: version: 4.5.0(@types/node@20.10.1) vite-plugin-dts: specifier: ^3.6.4 - version: 3.6.4(@types/node@20.10.1)(typescript@5.0.4)(vite@4.5.0) + version: 3.6.4(@types/node@20.10.1)(typescript@5.3.3)(vite@4.5.0) packages: @@ -633,7 +636,7 @@ packages: - supports-color dev: true - /@sveltejs/package@2.0.0(svelte@3.59.2)(typescript@5.0.4): + /@sveltejs/package@2.0.0(svelte@3.59.2)(typescript@5.3.3): resolution: {integrity: sha512-sANz/dJibOHOe83hl8pFWUSypqefdYwPp6SUr0SmJxTNQFB5dDECEqwAwoy28DWCQFYl7DU+C1hKkTXyuKOdug==} engines: {node: ^16.14 || >=18} hasBin: true @@ -644,7 +647,7 @@ packages: kleur: 4.1.5 sade: 1.8.1 svelte: 3.59.2 - svelte2tsx: 0.6.25(svelte@3.59.2)(typescript@5.0.4) + svelte2tsx: 0.6.25(svelte@3.59.2)(typescript@5.3.3) transitivePeerDependencies: - typescript dev: true @@ -685,13 +688,13 @@ packages: - supports-color dev: true - /@tailwindcss/forms@0.5.7(tailwindcss@3.3.5): + /@tailwindcss/forms@0.5.7(tailwindcss@3.4.1): resolution: {integrity: sha512-QE7X69iQI+ZXwldE+rzasvbJiyV/ju1FGHH0Qn2W3FKbuYtqp8LKcy6iSw79fVUT5/Vvf+0XgLCeYVG+UV6hOw==} peerDependencies: tailwindcss: '>=3.0.0 || >= 3.0.0-alpha.1' dependencies: mini-svg-data-uri: 1.4.4 - tailwindcss: 3.3.5 + tailwindcss: 3.4.1 dev: true /@types/argparse@1.0.38: @@ -755,7 +758,7 @@ packages: '@vue/shared': 3.3.9 dev: true - /@vue/language-core@1.8.24(typescript@5.0.4): + /@vue/language-core@1.8.24(typescript@5.3.3): resolution: {integrity: sha512-2ClHvij0WlsDWryPzXJCSpPc6rusZFNoVtRZGgGGkKCmKuIREDDKmH8j+1tYyxPYyH0qL6pZ6+IHD8KIm5nWAw==} peerDependencies: typescript: '*' @@ -771,7 +774,7 @@ packages: minimatch: 9.0.3 muggle-string: 0.3.1 path-browserify: 1.0.1 - typescript: 5.0.4 + typescript: 5.3.3 vue-template-compiler: 2.7.15 dev: true @@ -2252,7 +2255,55 @@ packages: typescript: 5.0.4 dev: true - /svelte2tsx@0.6.25(svelte@3.59.2)(typescript@5.0.4): + /svelte-preprocess@5.1.1(postcss@8.4.31)(svelte@3.59.2)(typescript@5.3.3): + resolution: {integrity: sha512-p/Dp4hmrBW5mrCCq29lEMFpIJT2FZsRlouxEc5qpbOmXRbaFs7clLs8oKPwD3xCFyZfv1bIhvOzpQkhMEVQdMw==} + 
engines: {node: '>= 14.10.0'} + requiresBuild: true + peerDependencies: + '@babel/core': ^7.10.2 + coffeescript: ^2.5.1 + less: ^3.11.3 || ^4.0.0 + postcss: ^7 || ^8 + postcss-load-config: ^2.1.0 || ^3.0.0 || ^4.0.0 + pug: ^3.0.0 + sass: ^1.26.8 + stylus: ^0.55.0 + sugarss: ^2.0.0 || ^3.0.0 || ^4.0.0 + svelte: ^3.23.0 || ^4.0.0-next.0 || ^4.0.0 || ^5.0.0-next.0 + typescript: '>=3.9.5 || ^4.0.0 || ^5.0.0' + peerDependenciesMeta: + '@babel/core': + optional: true + coffeescript: + optional: true + less: + optional: true + postcss: + optional: true + postcss-load-config: + optional: true + pug: + optional: true + sass: + optional: true + stylus: + optional: true + sugarss: + optional: true + typescript: + optional: true + dependencies: + '@types/pug': 2.0.9 + detect-indent: 6.1.0 + magic-string: 0.27.0 + postcss: 8.4.31 + sorcery: 0.11.0 + strip-indent: 3.0.0 + svelte: 3.59.2 + typescript: 5.3.3 + dev: true + + /svelte2tsx@0.6.25(svelte@3.59.2)(typescript@5.3.3): resolution: {integrity: sha512-hhBKL5X9gGvKQAZ9xLoHnbE9Yb00HxEZJlxcj2drxWK+Tpqcs/bnodjSfCGbqEhvNaUXYNbVL7s4dEXT+o0f6w==} peerDependencies: svelte: ^3.55 || ^4.0.0-next.0 || ^4.0 || ^5.0.0-next.0 @@ -2261,7 +2312,7 @@ packages: dedent-js: 1.0.1 pascal-case: 3.1.2 svelte: 3.59.2 - typescript: 5.0.4 + typescript: 5.3.3 dev: true /svelte@3.59.2: @@ -2269,8 +2320,8 @@ packages: engines: {node: '>= 8'} dev: true - /tailwindcss@3.3.5: - resolution: {integrity: sha512-5SEZU4J7pxZgSkv7FP1zY8i2TIAOooNZ1e/OGtxIEv6GltpoiXUqWvLy89+a10qYTB1N5Ifkuw9lqQkN9sscvA==} + /tailwindcss@3.4.1: + resolution: {integrity: sha512-qAYmXRfk3ENzuPBakNK0SRrUDipP8NQnEY6772uDhflcQz5EhRdD7JNZxyrFHVQNCwULPBn6FNPp9brpO7ctcA==} engines: {node: '>=14.0.0'} hasBin: true dependencies: @@ -2367,6 +2418,12 @@ packages: hasBin: true dev: true + /typescript@5.3.3: + resolution: {integrity: sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==} + engines: {node: '>=14.17'} + hasBin: true + dev: true + /undici-types@5.26.5: resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} dev: true @@ -2409,7 +2466,7 @@ packages: engines: {node: '>= 0.10'} dev: true - /vite-plugin-dts@3.6.4(@types/node@20.10.1)(typescript@5.0.4)(vite@4.5.0): + /vite-plugin-dts@3.6.4(@types/node@20.10.1)(typescript@5.3.3)(vite@4.5.0): resolution: {integrity: sha512-yOVhUI/kQhtS6lCXRYYLv2UUf9bftcwQK9ROxCX2ul17poLQs02ctWX7+vXB8GPRzH8VCK3jebEFtPqqijXx6w==} engines: {node: ^14.18.0 || >=16.0.0} peerDependencies: @@ -2421,12 +2478,12 @@ packages: dependencies: '@microsoft/api-extractor': 7.38.3(@types/node@20.10.1) '@rollup/pluginutils': 5.0.5(rollup@3.29.4) - '@vue/language-core': 1.8.24(typescript@5.0.4) + '@vue/language-core': 1.8.24(typescript@5.3.3) debug: 4.3.4 kolorist: 1.8.0 - typescript: 5.0.4 + typescript: 5.3.3 vite: 4.5.0(@types/node@20.10.1) - vue-tsc: 1.8.24(typescript@5.0.4) + vue-tsc: 1.8.24(typescript@5.3.3) transitivePeerDependencies: - '@types/node' - rollup @@ -2487,16 +2544,16 @@ packages: he: 1.2.0 dev: true - /vue-tsc@1.8.24(typescript@5.0.4): + /vue-tsc@1.8.24(typescript@5.3.3): resolution: {integrity: sha512-eH1CSj231OzVEY5Hi7wS6ubzyOEwgr5jCptR0Ddf2SitGcaXIsPVDvrprm3eolCdyhDt3WS1Eb2F4fGX9BsUUw==} hasBin: true peerDependencies: typescript: '*' dependencies: '@volar/typescript': 1.11.1 - '@vue/language-core': 1.8.24(typescript@5.0.4) + '@vue/language-core': 1.8.24(typescript@5.3.3) semver: 7.5.4 - typescript: 5.0.4 + typescript: 5.3.3 dev: true /which@2.0.2: diff --git 
a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetHeader/WidgetHeader.svelte b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetHeader/WidgetHeader.svelte index 2d9e18376..1b10f4873 100644 --- a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetHeader/WidgetHeader.svelte +++ b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetHeader/WidgetHeader.svelte @@ -66,15 +66,17 @@
{#if pipeline && task} - - - +
+ + + +
{/if} + {#if validExamples.length && applyWidgetExample} {/if} diff --git a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOuputConvoBubble/WidgetOutputConvoBubble.svelte b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOuputConvoBubble/WidgetOutputConvoBubble.svelte index 31ab4efa5..7dfbd3770 100644 --- a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOuputConvoBubble/WidgetOutputConvoBubble.svelte +++ b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOuputConvoBubble/WidgetOutputConvoBubble.svelte @@ -4,7 +4,7 @@
diff --git a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte index 19bb9987e..d8b7ce5f5 100644 --- a/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte +++ b/packages/widgets/src/lib/components/InferenceWidget/shared/WidgetOutputConvo/WidgetOutputConvo.svelte @@ -18,7 +18,7 @@
-
+
Input a message to start chatting with {modelId}.
diff --git a/packages/widgets/src/lib/components/InferenceWidget/stores.ts b/packages/widgets/src/lib/components/InferenceWidget/stores.ts index 09a332991..ba5340e57 100644 --- a/packages/widgets/src/lib/components/InferenceWidget/stores.ts +++ b/packages/widgets/src/lib/components/InferenceWidget/stores.ts @@ -1,4 +1,4 @@ -import { writable } from "svelte/store"; +import { get, writable } from "svelte/store"; import type { ModelData } from "@huggingface/tasks"; import type { ModelLoadInfo, WidgetState } from "./shared/types.js"; @@ -8,6 +8,27 @@ export const widgetNoInference = writable>({}); export const widgetStates = writable>({}); +const tgiSupportedModels = writable | undefined>(undefined); + +export async function getTgiSupportedModels(url: string): Promise { + if (!get(tgiSupportedModels)) { + const response = await fetch(`${url}/framework/text-generation-inference`); + const output = await response.json(); + if (response.ok) { + tgiSupportedModels.set( + new Set( + (output as { model_id: string; task: string }[]) + .filter(({ task }) => task === "text-generation") + .map(({ model_id }) => model_id) + ) + ); + } else { + console.warn(response.status, output.error); + } + } + return tgiSupportedModels; +} + export function updateWidgetState(modelId: ModelData["id"], key: keyof WidgetState, val: boolean): void { widgetStates.update((states) => { // Check if the modelId exists, if not initialize it diff --git a/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte b/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte index 21eeaf6d7..71ae946b1 100644 --- a/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte +++ b/packages/widgets/src/lib/components/InferenceWidget/widgets/ConversationalWidget/ConversationalWidget.svelte @@ -2,15 +2,23 @@ import { onMount } from "svelte"; import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types.js"; import { Template } from "@huggingface/jinja"; - import type { SpecialTokensMap, TokenizerConfig, WidgetExampleTextInput } from "@huggingface/tasks"; + import type { + SpecialTokensMap, + TokenizerConfig, + WidgetExampleTextInput, + TextGenerationInput, + } from "@huggingface/tasks"; import { SPECIAL_TOKENS_ATTRIBUTES } from "@huggingface/tasks"; + import { HfInference } from "@huggingface/inference"; + import type { ConversationMessage } from "../../shared/types.js"; import WidgetOutputConvo from "../../shared/WidgetOutputConvo/WidgetOutputConvo.svelte"; import WidgetQuickInput from "../../shared/WidgetQuickInput/WidgetQuickInput.svelte"; import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte"; import { addInferenceParameters, callInferenceApi, updateUrl } from "../../shared/helpers.js"; import { isTextInput } from "../../shared/inputValidation.js"; - import { widgetStates } from "../../stores.js"; + import { widgetStates, getTgiSupportedModels } from "../../stores.js"; + import type { Writable } from "svelte/store"; export let apiToken: WidgetProps["apiToken"]; export let apiUrl: WidgetProps["apiUrl"]; @@ -20,24 +28,23 @@ export let shouldUpdateUrl: WidgetProps["shouldUpdateUrl"]; export let includeCredentials: WidgetProps["includeCredentials"]; + let tgiSupportedModels: Writable | undefined>; + $: isDisabled = $widgetStates?.[model.id]?.isDisabled; - let computeTime = ""; let messages: ConversationMessage[] = []; let error: string = ""; 
let isLoading = false; - let modelLoading = { - isLoading: false, - estimatedTime: 0, - }; let outputJson: string; let text = ""; let compiledTemplate: Template; let tokenizerConfig: TokenizerConfig; + let inferenceClient: HfInference | undefined = undefined; // Check config and compile template onMount(() => { + getTgiSupportedModels(apiUrl).then((store) => (tgiSupportedModels = store)); const config = model.config; if (config === undefined) { error = "Model config not found"; @@ -61,6 +68,8 @@ error = `Invalid chat template: "${(e as Error).message}"`; return; } + + inferenceClient = new HfInference(); }); async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) { @@ -68,6 +77,11 @@ return; } + if (!inferenceClient) { + error = "Inference client not ready"; + return; + } + const trimmedText = text.trim(); if (!trimmedText) { return; @@ -97,65 +111,47 @@ return; } - const requestBody = { + const input: TextGenerationInput & Required> = { inputs: chatText, parameters: { return_full_text: false, - max_new_tokens: 100, }, }; - addInferenceParameters(requestBody, model); + addInferenceParameters(input, model); isLoading = true; - - const res = await callInferenceApi( - apiUrl, - model.id, - requestBody, - apiToken, - (body) => parseOutput(body, messages), - withModelLoading, - includeCredentials, - isOnLoadCall - ); - - isLoading = false; - // Reset values - computeTime = ""; - error = ""; - modelLoading = { isLoading: false, estimatedTime: 0 }; - outputJson = ""; - - if (res.status === "success") { - computeTime = res.computeTime; - outputJson = res.outputJson; - if (res.output) { - messages = res.output; - } - // Emptying input value - text = ""; - } else if (res.status === "loading-model") { - modelLoading = { - isLoading: true, - estimatedTime: res.estimatedTime, - }; - getOutput({ withModelLoading: true }); - } else if (res.status === "error") { - error = res.error; - } - } - - function parseOutput(body: unknown, chat: ConversationMessage[]): ConversationMessage[] { - if (Array.isArray(body) && body.length) { - const text = body[0]?.generated_text ?? 
""; - - if (!text.length) { - throw new Error("Model did not generate a response."); + text = ""; + try { + if ($tgiSupportedModels?.has(model.id)) { + console.debug("Starting text generation using the TGI streaming API"); + let newMessage = { + role: "assistant", + content: "", + } satisfies ConversationMessage; + const previousMessages = [...messages]; + const tokenStream = inferenceClient.textGenerationStream({ + ...input, + model: model.id, + accessToken: apiToken, + }); + for await (const newToken of tokenStream) { + if (newToken.token.special) continue; + newMessage.content = newMessage.content + newToken.token.text; + messages = [...previousMessages, newMessage]; + } + } else { + console.debug("Starting text generation using the synchronous API"); + input.parameters.max_new_tokens = 100; + const output = await inferenceClient.textGeneration( + { ...input, model: model.id, accessToken: apiToken }, + { includeCredentials, dont_load_model: !withModelLoading } + ); + messages = [...messages, { role: "assistant", content: output.generated_text }]; } - - return [...chat, { role: "assistant", content: text }]; + } catch (e) { + error = `Something went wrong while requesting the Inference API: "${(e as Error).message}"`; } - throw new TypeError("Invalid output: output must be of type Array & non-empty"); + isLoading = false; } function extractSpecialTokensMap(tokenizerConfig: TokenizerConfig): SpecialTokensMap { @@ -202,7 +198,7 @@ submitButtonLabel="Send" /> - + diff --git a/packages/widgets/src/routes/+page.svelte b/packages/widgets/src/routes/+page.svelte index 17ab4517b..389ea9ab1 100644 --- a/packages/widgets/src/routes/+page.svelte +++ b/packages/widgets/src/routes/+page.svelte @@ -48,9 +48,22 @@ }, }, { - id: "WizardLM/WizardLM-70B-V1.0", + id: "microsoft/phi-2", pipeline_tag: "text-generation", inference: InferenceDisplayability.Yes, + config: { + architectures: ["PhiForCausalLM"], + model_type: "phi", + auto_map: { + AutoConfig: "configuration_phi.PhiConfig", + AutoModelForCausalLM: "modeling_phi.PhiForCausalLM", + }, + tokenizer_config: { + bos_token: "<|endoftext|>", + eos_token: "<|endoftext|>", + unk_token: "<|endoftext|>", + }, + }, }, { id: "openai/clip-vit-base-patch16", diff --git a/packages/widgets/tailwind.config.cjs b/packages/widgets/tailwind.config.cjs index 6853e61fa..10f977a9e 100644 --- a/packages/widgets/tailwind.config.cjs +++ b/packages/widgets/tailwind.config.cjs @@ -37,6 +37,9 @@ module.exports = { sans: ["Source Sans Pro", ...defaultTheme.fontFamily.sans], mono: ["IBM Plex Mono", ...defaultTheme.fontFamily.mono], }, + fontSize: { + smd: "0.94rem", + }, }, }, plugins: [require("@tailwindcss/forms")],