diff --git a/app/api/alibaba.ts b/app/api/alibaba.ts index 894b1ae4c04..20f6caefa8d 100644 --- a/app/api/alibaba.ts +++ b/app/api/alibaba.ts @@ -8,7 +8,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; const serverConfig = getServerSideConfig(); @@ -89,7 +89,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.Alibaba as string, diff --git a/app/api/anthropic.ts b/app/api/anthropic.ts index 7a44443710f..b96637b2c8c 100644 --- a/app/api/anthropic.ts +++ b/app/api/anthropic.ts @@ -9,7 +9,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "./auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]); @@ -122,7 +122,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.Anthropic as string, diff --git a/app/api/baidu.ts b/app/api/baidu.ts index 0408b43c5bc..0f4e05ee86c 100644 --- a/app/api/baidu.ts +++ b/app/api/baidu.ts @@ -8,7 +8,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; import { getAccessToken } from "@/app/utils/baidu"; const serverConfig = getServerSideConfig(); @@ -104,7 +104,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.Baidu as string, diff --git a/app/api/bytedance.ts b/app/api/bytedance.ts index cb65b106109..51b39ceb7cb 100644 --- a/app/api/bytedance.ts +++ b/app/api/bytedance.ts @@ -8,7 +8,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; const serverConfig = getServerSideConfig(); @@ -88,7 +88,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.ByteDance as string, diff --git a/app/api/common.ts b/app/api/common.ts index 495a12ccdbb..8b75d4aedf6 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -2,7 +2,7 @@ import { NextRequest, NextResponse } from "next/server"; import { getServerSideConfig } from "../config/server"; import { OPENAI_BASE_URL, ServiceProvider } from "../constant"; import { cloudflareAIGatewayUrl } from "../utils/cloudflare"; -import { getModelProvider, isModelAvailableInServer } from "../utils/model"; +import { getModelProvider, isModelNotavailableInServer } from "../utils/model"; const serverConfig = getServerSideConfig(); @@ -118,15 +118,14 @@ export async function requestOpenai(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, - ServiceProvider.OpenAI as string, - ) || - isModelAvailableInServer( - serverConfig.customModels, - jsonBody?.model as string, - ServiceProvider.Azure as string, + [ + ServiceProvider.OpenAI, + ServiceProvider.Azure, + jsonBody?.model as string, // support provider-unspecified model + ], ) ) { return NextResponse.json( diff --git a/app/api/glm.ts b/app/api/glm.ts index 3625b9f7bf9..8431c5db5b0 100644 --- a/app/api/glm.ts +++ b/app/api/glm.ts @@ -8,7 +8,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; const serverConfig = getServerSideConfig(); @@ -89,7 +89,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.ChatGLM as string, diff --git a/app/api/iflytek.ts b/app/api/iflytek.ts index 8b8227dce1f..6624f74e9ab 100644 --- a/app/api/iflytek.ts +++ b/app/api/iflytek.ts @@ -8,7 +8,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; // iflytek const serverConfig = getServerSideConfig(); @@ -89,7 +89,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.Iflytek as string, diff --git a/app/api/moonshot.ts b/app/api/moonshot.ts index 5bf4807e3e6..792d14d3334 100644 --- a/app/api/moonshot.ts +++ b/app/api/moonshot.ts @@ -8,7 +8,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; const serverConfig = getServerSideConfig(); @@ -88,7 +88,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.Moonshot as string, diff --git a/app/api/xai.ts b/app/api/xai.ts index a4ee8b39731..4aad5e5fb3e 100644 --- a/app/api/xai.ts +++ b/app/api/xai.ts @@ -8,7 +8,7 @@ import { import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "@/app/api/auth"; -import { isModelAvailableInServer } from "@/app/utils/model"; +import { isModelNotavailableInServer } from "@/app/utils/model"; const serverConfig = getServerSideConfig(); @@ -88,7 +88,7 @@ async function request(req: NextRequest) { // not undefined and is false if ( - isModelAvailableInServer( + isModelNotavailableInServer( serverConfig.customModels, jsonBody?.model as string, ServiceProvider.XAI as string, diff --git a/app/client/platforms/glm.ts b/app/client/platforms/glm.ts index a7965947fab..a8d1869e30e 100644 --- a/app/client/platforms/glm.ts +++ b/app/client/platforms/glm.ts @@ -21,16 +21,108 @@ import { SpeechOptions, } from "../api"; import { getClientConfig } from "@/app/config/client"; -import { getMessageTextContent } from "@/app/utils"; +import { getMessageTextContent, isVisionModel } from "@/app/utils"; import { RequestPayload } from "./openai"; import { fetch } from "@/app/utils/stream"; +import { preProcessImageContent } from "@/app/utils/chat"; + +interface BasePayload { + model: string; +} + +interface ChatPayload extends BasePayload { + messages: ChatOptions["messages"]; + stream?: boolean; + temperature?: number; + presence_penalty?: number; + frequency_penalty?: number; + top_p?: number; +} + +interface ImageGenerationPayload extends BasePayload { + prompt: string; + size?: string; + user_id?: string; +} + +interface VideoGenerationPayload extends BasePayload { + prompt: string; + duration?: number; + resolution?: string; + user_id?: string; +} + +type ModelType = "chat" | "image" | "video"; export class ChatGLMApi implements LLMApi { private disableListModels = true; + private getModelType(model: string): ModelType { + if (model.startsWith("cogview-")) return "image"; + if (model.startsWith("cogvideo-")) return "video"; + return "chat"; + } + + private getModelPath(type: ModelType): string { + switch (type) { + case "image": + return ChatGLM.ImagePath; + case "video": + return ChatGLM.VideoPath; + default: + return ChatGLM.ChatPath; + } + } + + private createPayload( + messages: ChatOptions["messages"], + modelConfig: any, + options: ChatOptions, + ): BasePayload { + const modelType = this.getModelType(modelConfig.model); + const lastMessage = messages[messages.length - 1]; + const prompt = + typeof lastMessage.content === "string" + ? lastMessage.content + : lastMessage.content.map((c) => c.text).join("\n"); + + switch (modelType) { + case "image": + return { + model: modelConfig.model, + prompt, + size: options.config.size, + } as ImageGenerationPayload; + default: + return { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + } as ChatPayload; + } + } + + private parseResponse(modelType: ModelType, json: any): string { + switch (modelType) { + case "image": { + const imageUrl = json.data?.[0]?.url; + return imageUrl ? `![Generated Image](${imageUrl})` : ""; + } + case "video": { + const videoUrl = json.data?.[0]?.url; + return videoUrl ? `` : ""; + } + default: + return this.extractMessage(json); + } + } + path(path: string): string { const accessStore = useAccessStore.getState(); - let baseUrl = ""; if (accessStore.useCustomConfig) { @@ -51,7 +143,6 @@ export class ChatGLMApi implements LLMApi { } console.log("[Proxy Endpoint] ", baseUrl, path); - return [baseUrl, path].join("/"); } @@ -64,9 +155,12 @@ export class ChatGLMApi implements LLMApi { } async chat(options: ChatOptions) { + const visionModel = isVisionModel(options.config.model); const messages: ChatOptions["messages"] = []; for (const v of options.messages) { - const content = getMessageTextContent(v); + const content = visionModel + ? await preProcessImageContent(v.content) + : getMessageTextContent(v); messages.push({ role: v.role, content }); } @@ -78,25 +172,16 @@ export class ChatGLMApi implements LLMApi { providerName: options.config.providerName, }, }; + const modelType = this.getModelType(modelConfig.model); + const requestPayload = this.createPayload(messages, modelConfig, options); + const path = this.path(this.getModelPath(modelType)); - const requestPayload: RequestPayload = { - messages, - stream: options.config.stream, - model: modelConfig.model, - temperature: modelConfig.temperature, - presence_penalty: modelConfig.presence_penalty, - frequency_penalty: modelConfig.frequency_penalty, - top_p: modelConfig.top_p, - }; + console.log(`[Request] glm ${modelType} payload: `, requestPayload); - console.log("[Request] glm payload: ", requestPayload); - - const shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); try { - const chatPath = this.path(ChatGLM.ChatPath); const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), @@ -104,12 +189,23 @@ export class ChatGLMApi implements LLMApi { headers: getHeaders(), }; - // make a fetch request const requestTimeoutId = setTimeout( () => controller.abort(), REQUEST_TIMEOUT_MS, ); + if (modelType === "image" || modelType === "video") { + const res = await fetch(path, chatPayload); + clearTimeout(requestTimeoutId); + + const resJson = await res.json(); + console.log(`[Response] glm ${modelType}:`, resJson); + const message = this.parseResponse(modelType, resJson); + options.onFinish(message, res); + return; + } + + const shouldStream = !!options.config.stream; if (shouldStream) { const [tools, funcs] = usePluginStore .getState() @@ -117,7 +213,7 @@ export class ChatGLMApi implements LLMApi { useChatStore.getState().currentSession().mask?.plugin || [], ); return stream( - chatPath, + path, requestPayload, getHeaders(), tools as any, @@ -125,7 +221,6 @@ export class ChatGLMApi implements LLMApi { controller, // parseSSE (text: string, runTools: ChatMessageTool[]) => { - // console.log("parseSSE", text, runTools); const json = JSON.parse(text); const choices = json.choices as Array<{ delta: { @@ -154,7 +249,7 @@ export class ChatGLMApi implements LLMApi { } return choices[0]?.delta?.content; }, - // processToolMessage, include tool_calls message and tool call results + // processToolMessage ( requestPayload: RequestPayload, toolCallMessage: any, @@ -172,7 +267,7 @@ export class ChatGLMApi implements LLMApi { options, ); } else { - const res = await fetch(chatPath, chatPayload); + const res = await fetch(path, chatPayload); clearTimeout(requestTimeoutId); const resJson = await res.json(); @@ -184,6 +279,7 @@ export class ChatGLMApi implements LLMApi { options.onError?.(e as Error); } } + async usage() { return { used: 0, diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index a7bce4fc2d0..5ca8e1071a7 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -60,9 +60,18 @@ export class GeminiProApi implements LLMApi { extractMessage(res: any) { console.log("[Response] gemini-pro response: ", res); + const getTextFromParts = (parts: any[]) => { + if (!Array.isArray(parts)) return ""; + + return parts + .map((part) => part?.text || "") + .filter((text) => text.trim() !== "") + .join("\n\n"); + }; + return ( - res?.candidates?.at(0)?.content?.parts.at(0)?.text || - res?.at(0)?.candidates?.at(0)?.content?.parts.at(0)?.text || + getTextFromParts(res?.candidates?.at(0)?.content?.parts) || + getTextFromParts(res?.at(0)?.candidates?.at(0)?.content?.parts) || res?.error?.message || "" ); @@ -223,7 +232,10 @@ export class GeminiProApi implements LLMApi { }, }); } - return chunkJson?.candidates?.at(0)?.content.parts.at(0)?.text; + return chunkJson?.candidates + ?.at(0) + ?.content.parts?.map((part: { text: string }) => part.text) + .join("\n\n"); }, // processToolMessage, include tool_calls message and tool call results ( diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 15cfb7ca602..5a110b84bea 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -24,7 +24,7 @@ import { stream, } from "@/app/utils/chat"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; -import { DalleSize, DalleQuality, DalleStyle } from "@/app/typing"; +import { ModelSize, DalleQuality, DalleStyle } from "@/app/typing"; import { ChatOptions, @@ -73,7 +73,7 @@ export interface DalleRequestPayload { prompt: string; response_format: "url" | "b64_json"; n: number; - size: DalleSize; + size: ModelSize; quality: DalleQuality; style: DalleStyle; } diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 51fe74fe7be..f34f7d78e09 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -72,6 +72,8 @@ import { isDalle3, showPlugins, safeLocalStorage, + getModelSizes, + supportsCustomSize, } from "../utils"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; @@ -79,7 +81,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; import dynamic from "next/dynamic"; import { ChatControllerPool } from "../client/controller"; -import { DalleSize, DalleQuality, DalleStyle } from "../typing"; +import { DalleQuality, DalleStyle, ModelSize } from "../typing"; import { Prompt, usePromptStore } from "../store/prompt"; import Locale from "../locales"; @@ -519,10 +521,11 @@ export function ChatActions(props: { const [showSizeSelector, setShowSizeSelector] = useState(false); const [showQualitySelector, setShowQualitySelector] = useState(false); const [showStyleSelector, setShowStyleSelector] = useState(false); - const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"]; + const modelSizes = getModelSizes(currentModel); const dalle3Qualitys: DalleQuality[] = ["standard", "hd"]; const dalle3Styles: DalleStyle[] = ["vivid", "natural"]; - const currentSize = session.mask.modelConfig?.size ?? "1024x1024"; + const currentSize = + session.mask.modelConfig?.size ?? ("1024x1024" as ModelSize); const currentQuality = session.mask.modelConfig?.quality ?? "standard"; const currentStyle = session.mask.modelConfig?.style ?? "vivid"; @@ -673,7 +676,7 @@ export function ChatActions(props: { /> )} - {isDalle3(currentModel) && ( + {supportsCustomSize(currentModel) && ( setShowSizeSelector(true)} text={currentSize} @@ -684,7 +687,7 @@ export function ChatActions(props: { {showSizeSelector && ( ({ + items={modelSizes.map((m) => ({ title: m, value: m, }))} diff --git a/app/components/sidebar.tsx b/app/components/sidebar.tsx index a5e33b15ea3..fa4caee0d9d 100644 --- a/app/components/sidebar.tsx +++ b/app/components/sidebar.tsx @@ -22,7 +22,6 @@ import { MIN_SIDEBAR_WIDTH, NARROW_SIDEBAR_WIDTH, Path, - PLUGINS, REPO_URL, } from "../constant"; @@ -32,6 +31,12 @@ import dynamic from "next/dynamic"; import { showConfirm, Selector } from "./ui-lib"; import clsx from "clsx"; +const DISCOVERY = [ + { name: Locale.Plugin.Name, path: Path.Plugins }, + { name: "Stable Diffusion", path: Path.Sd }, + { name: Locale.SearchChat.Page.Title, path: Path.SearchChat }, +]; + const ChatList = dynamic(async () => (await import("./chat-list")).ChatList, { loading: () => null, }); @@ -219,7 +224,7 @@ export function SideBarTail(props: { export function SideBar(props: { className?: string }) { useHotKey(); const { onDragStart, shouldNarrow } = useDragSideBar(); - const [showPluginSelector, setShowPluginSelector] = useState(false); + const [showDiscoverySelector, setshowDiscoverySelector] = useState(false); const navigate = useNavigate(); const config = useAppConfig(); const chatStore = useChatStore(); @@ -254,21 +259,21 @@ export function SideBar(props: { className?: string }) { icon={} text={shouldNarrow ? undefined : Locale.Discovery.Name} className={styles["sidebar-bar-button"]} - onClick={() => setShowPluginSelector(true)} + onClick={() => setshowDiscoverySelector(true)} shadow /> - {showPluginSelector && ( + {showDiscoverySelector && ( { + ...DISCOVERY.map((item) => { return { title: item.name, value: item.path, }; }), ]} - onClose={() => setShowPluginSelector(false)} + onClose={() => setshowDiscoverySelector(false)} onSelection={(s) => { navigate(s[0], { state: { fromHome: true } }); }} diff --git a/app/config/server.ts b/app/config/server.ts index 9d6b3c2b8da..bd88082169a 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -1,5 +1,6 @@ import md5 from "spark-md5"; import { DEFAULT_MODELS, DEFAULT_GA_ID } from "../constant"; +import { isGPT4Model } from "../utils/model"; declare global { namespace NodeJS { @@ -127,20 +128,12 @@ export const getServerSideConfig = () => { if (disableGPT4) { if (customModels) customModels += ","; - customModels += DEFAULT_MODELS.filter( - (m) => - (m.name.startsWith("gpt-4") || m.name.startsWith("chatgpt-4o") || m.name.startsWith("o1")) && - !m.name.startsWith("gpt-4o-mini"), - ) + customModels += DEFAULT_MODELS.filter((m) => isGPT4Model(m.name)) .map((m) => "-" + m.name) .join(","); - if ( - (defaultModel.startsWith("gpt-4") || - defaultModel.startsWith("chatgpt-4o") || - defaultModel.startsWith("o1")) && - !defaultModel.startsWith("gpt-4o-mini") - ) + if (defaultModel && isGPT4Model(defaultModel)) { defaultModel = ""; + } } const isStability = !!process.env.STABILITY_API_KEY; diff --git a/app/constant.ts b/app/constant.ts index 5759411af17..dcb68ce43bd 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -233,6 +233,8 @@ export const XAI = { export const ChatGLM = { ExampleEndpoint: CHATGLM_BASE_URL, ChatPath: "api/paas/v4/chat/completions", + ImagePath: "api/paas/v4/images/generations", + VideoPath: "api/paas/v4/videos/generations", }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang @@ -303,6 +305,7 @@ export const VISION_MODEL_REGEXES = [ /qwen2-vl/, /gpt-4-turbo(?!.*preview)/, // Matches "gpt-4-turbo" but not "gpt-4-turbo-preview" /^dall-e-3$/, // Matches exactly "dall-e-3" + /glm-4v/, ]; export const EXCLUDE_VISION_MODEL_REGEXES = [/claude-3-5-haiku-20241022/]; @@ -431,6 +434,15 @@ const chatglmModels = [ "glm-4-long", "glm-4-flashx", "glm-4-flash", + "glm-4v-plus", + "glm-4v", + "glm-4v-flash", // free + "cogview-3-plus", + "cogview-3", + "cogview-3-flash", // free + // 目前无法适配轮询任务 + // "cogvideox", + // "cogvideox-flash", // free ]; let seq = 1000; // 内置的模型序号生成器从1000开始 @@ -586,11 +598,6 @@ export const internalAllowedWebDavEndpoints = [ ]; export const DEFAULT_GA_ID = "G-89WN60ZK2E"; -export const PLUGINS = [ - { name: "Plugins", path: Path.Plugins }, - { name: "Stable Diffusion", path: Path.Sd }, - { name: "Search Chat", path: Path.SearchChat }, -]; export const SAAS_CHAT_URL = "https://nextchat.dev/chat"; export const SAAS_CHAT_UTM_URL = "https://nextchat.dev/chat?utm=github"; diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 47be019a809..0a49cef51f8 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -176,7 +176,7 @@ const cn = { }, }, Lang: { - Name: "Language", // ATTENTION: if you wanna add a new translation, please do not translate this value, leave it as `Language` + Name: "Language", // 注意:如果要添加新的翻译,请不要翻译此值,将它保留为 `Language` All: "所有语言", }, Avatar: "头像", @@ -630,7 +630,7 @@ const cn = { Sysmessage: "你是一个助手", }, SearchChat: { - Name: "搜索", + Name: "搜索聊天记录", Page: { Title: "搜索聊天记录", Search: "输入搜索关键词", diff --git a/app/locales/tw.ts b/app/locales/tw.ts index c800ad15d26..f10c793ab80 100644 --- a/app/locales/tw.ts +++ b/app/locales/tw.ts @@ -485,7 +485,7 @@ const tw = { }, }, SearchChat: { - Name: "搜尋", + Name: "搜尋聊天記錄", Page: { Title: "搜尋聊天記錄", Search: "輸入搜尋關鍵詞", diff --git a/app/store/config.ts b/app/store/config.ts index 4256eba925d..45e21b02697 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -1,5 +1,5 @@ import { LLMModel } from "../client/api"; -import { DalleSize, DalleQuality, DalleStyle } from "../typing"; +import { DalleQuality, DalleStyle, ModelSize } from "../typing"; import { getClientConfig } from "../config/client"; import { DEFAULT_INPUT_TEMPLATE, @@ -78,7 +78,7 @@ export const DEFAULT_CONFIG = { compressProviderName: "", enableInjectSystemPrompts: true, template: config?.template ?? DEFAULT_INPUT_TEMPLATE, - size: "1024x1024" as DalleSize, + size: "1024x1024" as ModelSize, quality: "standard" as DalleQuality, style: "vivid" as DalleStyle, }, diff --git a/app/typing.ts b/app/typing.ts index 0336be75d39..ecb327936fd 100644 --- a/app/typing.ts +++ b/app/typing.ts @@ -11,3 +11,14 @@ export interface RequestMessage { export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792"; export type DalleQuality = "standard" | "hd"; export type DalleStyle = "vivid" | "natural"; + +export type ModelSize = + | "1024x1024" + | "1792x1024" + | "1024x1792" + | "768x1344" + | "864x1152" + | "1344x768" + | "1152x864" + | "1440x720" + | "720x1440"; diff --git a/app/utils.ts b/app/utils.ts index 962e68a101c..810dc7842b1 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant"; import { fetch as tauriStreamFetch } from "./utils/stream"; import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant"; import { getClientConfig } from "./config/client"; +import { ModelSize } from "./typing"; export function trimTopic(topic: string) { // Fix an issue where double quotes still show in the Indonesian language @@ -271,6 +272,28 @@ export function isDalle3(model: string) { return "dall-e-3" === model; } +export function getModelSizes(model: string): ModelSize[] { + if (isDalle3(model)) { + return ["1024x1024", "1792x1024", "1024x1792"]; + } + if (model.toLowerCase().includes("cogview")) { + return [ + "1024x1024", + "768x1344", + "864x1152", + "1344x768", + "1152x864", + "1440x720", + "720x1440", + ]; + } + return []; +} + +export function supportsCustomSize(model: string): boolean { + return getModelSizes(model).length > 0; +} + export function showPlugins(provider: ServiceProvider, model: string) { if ( provider == ServiceProvider.OpenAI || diff --git a/app/utils/model.ts b/app/utils/model.ts index a1b7df1b61e..a1a38a2f81c 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -202,3 +202,52 @@ export function isModelAvailableInServer( const modelTable = collectModelTable(DEFAULT_MODELS, customModels); return modelTable[fullName]?.available === false; } + +/** + * Check if the model name is a GPT-4 related model + * + * @param modelName The name of the model to check + * @returns True if the model is a GPT-4 related model (excluding gpt-4o-mini) + */ +export function isGPT4Model(modelName: string): boolean { + return ( + (modelName.startsWith("gpt-4") || + modelName.startsWith("chatgpt-4o") || + modelName.startsWith("o1")) && + !modelName.startsWith("gpt-4o-mini") + ); +} + +/** + * Checks if a model is not available on any of the specified providers in the server. + * + * @param {string} customModels - A string of custom models, comma-separated. + * @param {string} modelName - The name of the model to check. + * @param {string|string[]} providerNames - A string or array of provider names to check against. + * + * @returns {boolean} True if the model is not available on any of the specified providers, false otherwise. + */ +export function isModelNotavailableInServer( + customModels: string, + modelName: string, + providerNames: string | string[], +): boolean { + // Check DISABLE_GPT4 environment variable + if ( + process.env.DISABLE_GPT4 === "1" && + isGPT4Model(modelName.toLowerCase()) + ) { + return true; + } + + const modelTable = collectModelTable(DEFAULT_MODELS, customModels); + + const providerNamesArray = Array.isArray(providerNames) + ? providerNames + : [providerNames]; + for (const providerName of providerNamesArray) { + const fullName = `${modelName}@${providerName.toLowerCase()}`; + if (modelTable?.[fullName]?.available === true) return false; + } + return true; +} diff --git a/test/model-available.test.ts b/test/model-available.test.ts new file mode 100644 index 00000000000..5c9fa9977d2 --- /dev/null +++ b/test/model-available.test.ts @@ -0,0 +1,80 @@ +import { isModelNotavailableInServer } from "../app/utils/model"; + +describe("isModelNotavailableInServer", () => { + test("test model will return false, which means the model is available", () => { + const customModels = ""; + const modelName = "gpt-4"; + const providerNames = "OpenAI"; + const result = isModelNotavailableInServer( + customModels, + modelName, + providerNames, + ); + expect(result).toBe(false); + }); + + test("test model will return true when model is not available in custom models", () => { + const customModels = "-all,gpt-4o-mini"; + const modelName = "gpt-4"; + const providerNames = "OpenAI"; + const result = isModelNotavailableInServer( + customModels, + modelName, + providerNames, + ); + expect(result).toBe(true); + }); + + test("should respect DISABLE_GPT4 setting", () => { + process.env.DISABLE_GPT4 = "1"; + const result = isModelNotavailableInServer("", "gpt-4", "OpenAI"); + expect(result).toBe(true); + }); + + test("should handle empty provider names", () => { + const result = isModelNotavailableInServer("-all,gpt-4", "gpt-4", ""); + expect(result).toBe(true); + }); + + test("should be case insensitive for model names", () => { + const result = isModelNotavailableInServer("-all,GPT-4", "gpt-4", "OpenAI"); + expect(result).toBe(true); + }); + + test("support passing multiple providers, model unavailable on one of the providers will return true", () => { + const customModels = "-all,gpt-4@google"; + const modelName = "gpt-4"; + const providerNames = ["OpenAI", "Azure"]; + const result = isModelNotavailableInServer( + customModels, + modelName, + providerNames, + ); + expect(result).toBe(true); + }); + + // FIXME: 这个测试用例有问题,需要修复 + // test("support passing multiple providers, model available on one of the providers will return false", () => { + // const customModels = "-all,gpt-4@google"; + // const modelName = "gpt-4"; + // const providerNames = ["OpenAI", "Google"]; + // const result = isModelNotavailableInServer( + // customModels, + // modelName, + // providerNames, + // ); + // expect(result).toBe(false); + // }); + + test("test custom model without setting provider", () => { + const customModels = "-all,mistral-large"; + const modelName = "mistral-large"; + const providerNames = modelName; + const result = isModelNotavailableInServer( + customModels, + modelName, + providerNames, + ); + expect(result).toBe(false); + }); +});