Skip to content

Commit

Permalink
add fal-ai as a provider
Browse files Browse the repository at this point in the history
  • Loading branch information
SBrandeis committed Jan 7, 2025
1 parent aa7d1ca commit bf47b3b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 12 deletions.
20 changes: 13 additions & 7 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
Expand All @@ -9,7 +10,8 @@ import { isUrl } from "./isUrl";
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";

/**
* Loaded from huggingface.co/api/tasks if needed
* Lazy-loaded from huggingface.co/api/tasks when needed
* Used to determine the default model to use when it's not user defined
*/
let tasks: Record<string, { models: { id: string }[] }> | null = null;

Expand All @@ -36,7 +38,7 @@ export async function makeRequestOptions(

const headers: Record<string, string> = {};
if (accessToken) {
headers["Authorization"] = `Bearer ${accessToken}`;
headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
}

if (!model && !tasks && taskHint) {
Expand Down Expand Up @@ -74,6 +76,9 @@ export async function makeRequestOptions(
case "together":
model = TOGETHER_MODEL_IDS[model]?.id ?? model;
break;
case "fal-ai":
model = FAL_AI_MODEL_IDS[model];
break;
default:
break;
}
Expand Down Expand Up @@ -120,8 +125,9 @@ export async function makeRequestOptions(
/// TODO we will proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
throw new Error("Inference proxying is not implemented yet");
} else {
/// This is an external key
switch (provider) {
case 'fal-ai':
return `${FAL_AI_API_BASE_URL}/${model}`;
case "replicate":
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
case "sambanova":
Expand Down Expand Up @@ -160,10 +166,10 @@ export async function makeRequestOptions(
body: binary
? args.data
: JSON.stringify({
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate"
? omit(otherArgs, "model")
: { ...otherArgs, model }),
}),
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
? omit(otherArgs, "model")
: { ...otherArgs, model }),
}),
...(credentials ? { credentials } : undefined),
signal: options?.signal,
};
Expand Down
6 changes: 3 additions & 3 deletions packages/inference/src/tasks/custom/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types";
import { makeRequestOptions } from "../../lib/makeRequestOptions";

/**
* Primitive to make custom calls to Inference Endpoints
* Primitive to make custom calls to the inference provider
*/
export async function request<T>(
args: RequestArgs,
Expand Down Expand Up @@ -35,8 +35,8 @@ export async function request<T>(
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
}
if (output.error) {
throw new Error(JSON.stringify(output.error));
if (output.error || output.detail) {
throw new Error(JSON.stringify(output.error ?? output.detail));
} else {
throw new Error(output);
}
Expand Down
6 changes: 5 additions & 1 deletion packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ interface OutputUrlImageGeneration {
* Recommended model: stabilityai/stable-diffusion-2
*/
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
if (args.provider === "together") {
if (args.provider === "together" || args.provider === "fal-ai") {
args.prompt = args.inputs;
args.inputs = "";
args.response_format = "base64";
Expand All @@ -70,6 +70,10 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
taskHint: "text-to-image",
});
if (res && typeof res === "object") {
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
const image = await fetch(res.images[0].url);
return await image.blob();
}
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
const base64Data = res.data[0].b64_json;
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
Expand Down
2 changes: 1 addition & 1 deletion packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export interface Options {

export type InferenceTask = Exclude<PipelineType, "other">;

export const INFERENCE_PROVIDERS = ["replicate", "sambanova", "together", "hf-inference"] as const;
export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const;
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

export interface BaseArgs {
Expand Down

0 comments on commit bf47b3b

Please sign in to comment.