✨ [Widgets] Enable streaming in the conversational widget #486

Merged · 22 commits · Feb 22, 2024
Changes from 7 commits
63 changes: 2 additions & 61 deletions packages/inference/src/tasks/nlp/textGeneration.ts
@@ -1,71 +1,12 @@
import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks/src/tasks/text-generation/inference";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

export type TextGenerationArgs = BaseArgs & {
/**
* A string to be generated from
*/
inputs: string;
parameters?: {
/**
* (Optional: True). Bool. Whether or not to use sampling, use greedy decoding otherwise.
*/
do_sample?: boolean;
/**
* (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input length it is a estimate of the size of generated text you want. Each new tokens slows down the request, so look for balance between response times and length of text generated.
*/
max_new_tokens?: number;
/**
* (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens for best results.
*/
max_time?: number;
/**
* (Default: 1). Integer. The number of proposition you want to be returned.
*/
num_return_sequences?: number;
/**
* (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized to not be picked in successive generation passes.
*/
repetition_penalty?: number;
/**
* (Default: True). Bool. If set to False, the return results will not contain the original query making it easier for prompting.
*/
return_full_text?: boolean;
/**
* (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
*/
temperature?: number;
/**
* (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
*/
top_k?: number;
/**
* (Default: None). Float to define the tokens that are within the sample operation of text generation. Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p.
*/
top_p?: number;
/**
* (Default: None). Integer. The maximum number of tokens from the input.
*/
truncate?: number;
/**
* (Default: []) List of strings. The model will stop generating text when one of the strings in the list is generated.
* **/
stop_sequences?: string[];
};
};

export interface TextGenerationOutput {
/**
* The continuated string
*/
generated_text: string;
}

/**
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
*/
export async function textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationOutput> {
export async function textGeneration(args: BaseArgs & TextGenerationInput, options?: Options): Promise<TextGenerationOutput> {
const res = await request<TextGenerationOutput[]>(args, {
...options,
taskHint: "text-generation",
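
For reference, a minimal usage sketch of the updated signature, which now takes BaseArgs & TextGenerationInput from @huggingface/tasks instead of the locally defined TextGenerationArgs (the model, token, and prompt below are placeholders):

```ts
import { HfInference } from "@huggingface/inference";

// Placeholder token and model; `parameters` is now typed by TextGenerationInput
// from @huggingface/tasks rather than by a local interface.
const hf = new HfInference("hf_xxx");

const output = await hf.textGeneration({
  model: "gpt2",
  inputs: "The answer to the universe is",
  parameters: { max_new_tokens: 20, return_full_text: false },
});

console.log(output.generated_text);
```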
7 changes: 4 additions & 3 deletions packages/inference/src/tasks/nlp/textGenerationStream.ts
@@ -1,6 +1,7 @@
import type { Options } from "../../types";
import type { BaseArgs, Options } from "../../types";
import { streamingRequest } from "../custom/streamingRequest";
import type { TextGenerationArgs } from "./textGeneration";

import type { TextGenerationInput } from "@huggingface/tasks/src/tasks/text-generation/inference";

export interface TextGenerationStreamToken {
/** Token ID from the model tokenizer */
@@ -85,7 +86,7 @@ export interface TextGenerationStreamOutput {
* Use to continue text from a prompt. Same as `textGeneration` but returns generator that can be read one token at a time
*/
export async function* textGenerationStream(
args: TextGenerationArgs,
args: BaseArgs & TextGenerationInput,
options?: Options
): AsyncGenerator<TextGenerationStreamOutput> {
yield* streamingRequest<TextGenerationStreamOutput>(args, {
mishig25 (Collaborator) commented on Feb 21, 2024:

Should TextGenerationStreamOutput live in @huggingface/tasks/src/tasks/text-generation/inference, just like TextGenerationInput & TextGenerationOutput?

Reply from the PR author (Contributor):

Good question - cc @Wauplin, with whom we discussed this previously.

The current philosophy is not to type the streaming mode, because it's transfer-specific rather than inference-specific.

Another contributor replied:

Indeed, we mentioned it here once: #468 (comment).

As a matter of fact, I came to the conclusion today that we should specify the stream parameter and the streamed output in our JS specs. I am currently starting to use the generated types in Python (see the ongoing PR), and for now I've kept text_generation apart since I'm missing TextGenerationStreamResponse (defined here, but I don't want to mix the auto-generated types with the previous definitions). I agree it's more "transfer-specific" than "inference-specific", but setting stream=True changes the output format, so we need to document that somewhere.
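
For context, a sketch of the streamed chunk shape being discussed, abridged from the current definitions in @huggingface/inference (not yet part of @huggingface/tasks); with streaming enabled the endpoint emits a sequence of these chunks instead of a single { generated_text } object:

```ts
// Abridged sketch of the current @huggingface/inference definitions.
interface TextGenerationStreamToken {
  /** Token ID from the model tokenizer */
  id: number;
  /** Token text */
  text: string;
  /** Log probability of the token */
  logprob: number;
  /** Whether the token is a special token (usually filtered out when rendering) */
  special: boolean;
}

interface TextGenerationStreamOutput {
  /** The token generated in this chunk */
  token: TextGenerationStreamToken;
  /** Complete generated text, only populated on the final chunk */
  generated_text: string | null;
  /** Generation details (finish reason, token counts, ...), only populated on the final chunk */
  details: unknown; // TextGenerationStreamDetails | null in the full definition
}
```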

3 changes: 2 additions & 1 deletion packages/widgets/package.json
@@ -46,7 +46,8 @@
],
"dependencies": {
"@huggingface/tasks": "workspace:^",
"@huggingface/jinja": "workspace:^"
"@huggingface/jinja": "workspace:^",
"@huggingface/inference": "workspace:^"
},
"peerDependencies": {
"svelte": "^3.59.2"
3 changes: 3 additions & 0 deletions packages/widgets/pnpm-lock.yaml

(Generated lockfile; diff not rendered.)

@@ -1,5 +1,5 @@
<script lang="ts">
import { updateWidgetState } from "../../stores.js";
import { tgiSupportedModels, updateWidgetState } from "../../stores.js";
import { TASKS_DATA } from "@huggingface/tasks";
import type { WidgetExample, WidgetExampleAttribute } from "@huggingface/tasks";
import type { WidgetProps, ExampleRunOpts } from "../types.js";
@@ -66,15 +66,23 @@
</div>
<div class="mb-0.5 flex w-full max-w-full flex-wrap items-center justify-between text-sm text-gray-500">
{#if pipeline && task}
<a
class={TASKS_DATA[task] ? "hover:underline" : undefined}
href={TASKS_DATA[task] ? `/tasks/${task}` : undefined}
target="_blank"
title={TASKS_DATA[task] ? `Learn more about ${task}` : undefined}
>
<PipelineTag classNames="mr-2 mb-1.5" {pipeline} />
</a>
<div class="flex gap-4 items-center mb-1.5">
<a
class={TASKS_DATA[task] ? "hover:underline" : undefined}

A Collaborator suggested the following change:
class={TASKS_DATA[task] ? "hover:underline" : undefined}
class:hover:underline={TASKS_DATA[task]}

Didn't test this - can we use the new Svelte syntax?

href={TASKS_DATA[task] ? `/tasks/${task}` : undefined}
target="_blank"
title={TASKS_DATA[task] ? `Learn more about ${task}` : undefined}
>
<PipelineTag {pipeline} />
</a>
{#if $tgiSupportedModels?.has(model.id)}
<p class="text-xs text-gray-400">
Streaming with <a href="https://huggingface.co/docs/text-generation-inference" class="underline">TGI</a>

A Collaborator suggested the following change:
Streaming with <a href="https://huggingface.co/docs/text-generation-inference" class="underline">TGI</a>
Streaming with <a href="https://huggingface.co/docs/text-generation-inference" class="hover:underline">TGI</a>

Made it hover:underline.

</p>
{/if}
</div>
{/if}

{#if validExamples.length && applyWidgetExample}
<WidgetExamples {validExamples} {isLoading} {applyWidgetExample} {callApiOnMount} {exampleQueryParams} />
{/if}

@@ -11,8 +11,8 @@
import WidgetHeader from "../WidgetHeader/WidgetHeader.svelte";
import WidgetInfo from "../WidgetInfo/WidgetInfo.svelte";
import IconCross from "../../..//Icons/IconCross.svelte";
import { getModelLoadInfo } from "../../..//InferenceWidget/shared/helpers.js";
import { modelLoadStates, widgetStates, updateWidgetState } from "../../stores.js";
import { getModelLoadInfo, getTgiSupportedModels } from "../../..//InferenceWidget/shared/helpers.js";
import { modelLoadStates, widgetStates, updateWidgetState, tgiSupportedModels } from "../../stores.js";

export let apiUrl: string;
export let model: WidgetProps["model"];
@@ -35,6 +35,10 @@
if (modelLoadInfo?.state === "TooBig") {
updateWidgetState(model.id, "isDisabled", true);
}

if (!$tgiSupportedModels) {
tgiSupportedModels.set(await getTgiSupportedModels(apiUrl));
}
})();
});
</script>

@@ -184,6 +184,21 @@ export async function getModelLoadInfo(
}
}

export async function getTgiSupportedModels(url: string): Promise<Set<string> | undefined> {
const response = await fetch(`${url}/framework/text-generation-inference`);
const output = await response.json();
if (response.ok) {
return new Set(
(output as { model_id: string; task: string }[])
.filter(({ task }) => task === "text-generation")
.map(({ model_id }) => model_id)
)
} else {
console.warn(response.status, output.error);
return undefined;
}
}

// Extend requestBody with user supplied parameters for Inference Endpoints (serverless)
export function addInferenceParameters(requestBody: Record<string, unknown>, model: ModelData): void {
const inference = model?.cardData?.inference;
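
A short, hypothetical usage sketch for the new helper; the base URL and model ID are illustrative placeholders, and in the widgets the result is cached in the tgiSupportedModels store on mount (see the WidgetWrapper change above):

```ts
// Hypothetical standalone usage of getTgiSupportedModels.
const tgiModels = await getTgiSupportedModels("https://api-inference.huggingface.co");

if (tgiModels?.has("microsoft/phi-2")) {
  // The model is served by text-generation-inference: streaming is available.
} else {
  // Fall back to the non-streaming textGeneration call.
}
```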
2 changes: 2 additions & 0 deletions packages/widgets/src/lib/components/InferenceWidget/stores.ts
@@ -8,6 +8,8 @@ export const widgetNoInference = writable<Record<ModelData["id"], boolean>>({});

export const widgetStates = writable<Record<ModelData["id"], WidgetState>>({});

export const tgiSupportedModels = writable<Set<string> | undefined>(undefined);

export function updateWidgetState(modelId: ModelData["id"], key: keyof WidgetState, val: boolean): void {
widgetStates.update((states) => {
// Check if the modelId exists, if not initialize it

@@ -4,13 +4,15 @@
import { Template } from "@huggingface/jinja";
import type { SpecialTokensMap, TokenizerConfig, WidgetExampleTextInput } from "@huggingface/tasks";
import { SPECIAL_TOKENS_ATTRIBUTES } from "@huggingface/tasks";
import { HfInference } from "@huggingface/inference";
import type { TextGenerationInput } from "@huggingface/tasks/src/tasks/text-generation/inference";

import WidgetOutputConvo from "../../shared/WidgetOutputConvo/WidgetOutputConvo.svelte";
import WidgetQuickInput from "../../shared/WidgetQuickInput/WidgetQuickInput.svelte";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import { addInferenceParameters, callInferenceApi, updateUrl } from "../../shared/helpers.js";
import { isTextInput } from "../../shared/inputValidation.js";
import { widgetStates } from "../../stores.js";
import { widgetStates, tgiSupportedModels } from "../../stores.js";

export let apiToken: WidgetProps["apiToken"];
export let apiUrl: WidgetProps["apiUrl"];
@@ -27,19 +29,15 @@
content: string;
}

let computeTime = "";
let messages: Message[] = [];
let error: string = "";
let isLoading = false;
let modelLoading = {
isLoading: false,
estimatedTime: 0,
};
let outputJson: string;
let text = "";

let compiledTemplate: Template;
let tokenizerConfig: TokenizerConfig;
let inferenceClient: HfInference | undefined = undefined;

// Check config and compile template
onMount(() => {
@@ -66,13 +64,20 @@
error = `Invalid chat template: "${(e as Error).message}"`;
return;
}

inferenceClient = new HfInference();
});

async function getOutput({ withModelLoading = false, isOnLoadCall = false }: InferenceRunOpts = {}) {
if (!compiledTemplate) {
return;
}

if (!inferenceClient) {
error = "Inference client not ready";
return;
}

const trimmedText = text.trim();
if (!trimmedText) {
return;
@@ -102,65 +107,47 @@
return;
}

const requestBody = {
const input: TextGenerationInput = {
inputs: chatText,
parameters: {
return_full_text: false,
max_new_tokens: 100,
},
};
addInferenceParameters(requestBody, model);
addInferenceParameters(input, model);

isLoading = true;

const res = await callInferenceApi(
apiUrl,
model.id,
requestBody,
apiToken,
(body) => parseOutput(body, messages),
withModelLoading,
includeCredentials,
isOnLoadCall
);

isLoading = false;
// Reset values
computeTime = "";
error = "";
modelLoading = { isLoading: false, estimatedTime: 0 };
outputJson = "";

if (res.status === "success") {
computeTime = res.computeTime;
outputJson = res.outputJson;
if (res.output) {
messages = res.output;
}
// Emptying input value
text = "";
} else if (res.status === "loading-model") {
modelLoading = {
isLoading: true,
estimatedTime: res.estimatedTime,
};
getOutput({ withModelLoading: true });
} else if (res.status === "error") {
error = res.error;
}
}

function parseOutput(body: unknown, chat: Message[]): Message[] {
if (Array.isArray(body) && body.length) {
const text = body[0]?.generated_text ?? "";

if (!text.length) {
throw new Error("Model did not generate a response.");
text = "";
try {
if ($tgiSupportedModels?.has(model.id)) {
console.debug("Starting text generation using the TGI streaming API");
let newMessage = {
role: "assistant",
content: "",
};
const previousMessages = [...messages];
const tokenStream = inferenceClient.textGenerationStream({
...input,
model: model.id,
accessToken: apiToken,
});
for await (const newToken of tokenStream) {
if (newToken.token.special) continue;
newMessage.content = newMessage.content + newToken.token.text;
messages = [...previousMessages, newMessage];
}
} else {
console.debug("Starting text generation using the synchronous API");
input.parameters.max_new_tokens = 100;
const output = await inferenceClient.textGeneration(
{ ...input, model: model.id, accessToken: apiToken },
{ includeCredentials, dont_load_model: !withModelLoading }
);
messages = [...messages, { role: "assistant", content: output.generated_text }];
}

return [...chat, { role: "assistant", content: text }];
} catch (e) {
error = `Something went wrong while requesting the Inference API: "${(e as Error).message}"`;
}
throw new TypeError("Invalid output: output must be of type Array & non-empty");
isLoading = false;
}

function extractSpecialTokensMap(tokenizerConfig: TokenizerConfig): SpecialTokensMap {
@@ -207,7 +194,7 @@
submitButtonLabel="Send"
/>

<WidgetInfo {model} {computeTime} {error} {modelLoading} />
<WidgetInfo {model} {error} />

<WidgetFooter {model} {isDisabled} {outputJson} />
</WidgetWrapper>
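
A condensed, hedged sketch of the branching logic the widget now implements (names simplified and lifted out of Svelte; not the exact component code):

```ts
import { HfInference } from "@huggingface/inference";

// Stream token by token when the model is served by TGI, otherwise fall back
// to the blocking text-generation call. `tgiSupported` stands in for the
// $tgiSupportedModels store value.
async function generateReply(
  hf: HfInference,
  tgiSupported: Set<string> | undefined,
  modelId: string,
  prompt: string
): Promise<string> {
  if (tgiSupported?.has(modelId)) {
    let content = "";
    for await (const chunk of hf.textGenerationStream({ model: modelId, inputs: prompt })) {
      // Skip special tokens (e.g. <|endoftext|>) so they never reach the UI.
      if (!chunk.token.special) {
        content += chunk.token.text;
      }
    }
    return content;
  }
  // Non-streaming fallback: cap the output length, as the widget does.
  const output = await hf.textGeneration({
    model: modelId,
    inputs: prompt,
    parameters: { max_new_tokens: 100 },
  });
  return output.generated_text;
}
```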
15 changes: 14 additions & 1 deletion packages/widgets/src/routes/+page.svelte
@@ -48,9 +48,22 @@
},
},
{
id: "WizardLM/WizardLM-70B-V1.0",
id: "microsoft/phi-2",
pipeline_tag: "text-generation",
inference: InferenceDisplayability.Yes,
config: {
architectures: ["PhiForCausalLM"],
model_type: "phi",
auto_map: {
AutoConfig: "configuration_phi.PhiConfig",
AutoModelForCausalLM: "modeling_phi.PhiForCausalLM",
},
tokenizer: {
bos_token: "<|endoftext|>",
eos_token: "<|endoftext|>",
unk_token: "<|endoftext|>",
},
},
},
{
id: "openai/clip-vit-base-patch16",