diff --git a/packages/hub/src/lib/commit.ts b/packages/hub/src/lib/commit.ts index 0123d88cd..ad2365a11 100644 --- a/packages/hub/src/lib/commit.ts +++ b/packages/hub/src/lib/commit.ts @@ -130,403 +130,416 @@ export async function* commitIter(params: CommitParams): AsyncGenerator(); - const allOperations = await Promise.all( - params.operations.map(async (operation) => { - if (operation.operation !== "addOrUpdate") { - return operation; - } + const abortController = new AbortController(); + const abortSignal = abortController.signal; - if (!(operation.content instanceof URL)) { - /** TS trick to enforce `content` to be a `Blob` */ - return { ...operation, content: operation.content }; - } + if (params.abortSignal) { + params.abortSignal.addEventListener("abort", () => abortController.abort()); + } - const lazyBlob = await createBlob(operation.content, { fetch: params.fetch }); + try { + const allOperations = await Promise.all( + params.operations.map(async (operation) => { + if (operation.operation !== "addOrUpdate") { + return operation; + } - params.abortSignal?.throwIfAborted(); + if (!(operation.content instanceof URL)) { + /** TS trick to enforce `content` to be a `Blob` */ + return { ...operation, content: operation.content }; + } - return { - ...operation, - content: lazyBlob, - }; - }) - ); + const lazyBlob = await createBlob(operation.content, { fetch: params.fetch }); - const gitAttributes = allOperations.filter(isFileOperation).find((op) => op.path === ".gitattributes")?.content; + abortSignal?.throwIfAborted(); - for (const operations of chunk(allOperations.filter(isFileOperation), 100)) { - const payload: ApiPreuploadRequest = { - gitAttributes: gitAttributes && (await gitAttributes.text()), - files: await Promise.all( - operations.map(async (operation) => ({ - path: operation.path, - size: operation.content.size, - sample: base64FromBytes(new Uint8Array(await operation.content.slice(0, 512).arrayBuffer())), - })) - ), - }; - - params.abortSignal?.throwIfAborted(); - - const res = await (params.fetch ?? fetch)( - `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/preupload/${encodeURIComponent( - params.branch ?? "main" - )}` + (params.isPullRequest ? "?create_pr=1" : ""), - { - method: "POST", - headers: { - ...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }), - "Content-Type": "application/json", - }, - body: JSON.stringify(payload), - signal: params.abortSignal, - } + return { + ...operation, + content: lazyBlob, + }; + }) ); - if (!res.ok) { - throw await createApiError(res); - } + const gitAttributes = allOperations.filter(isFileOperation).find((op) => op.path === ".gitattributes")?.content; + + for (const operations of chunk(allOperations.filter(isFileOperation), 100)) { + const payload: ApiPreuploadRequest = { + gitAttributes: gitAttributes && (await gitAttributes.text()), + files: await Promise.all( + operations.map(async (operation) => ({ + path: operation.path, + size: operation.content.size, + sample: base64FromBytes(new Uint8Array(await operation.content.slice(0, 512).arrayBuffer())), + })) + ), + }; + + abortSignal?.throwIfAborted(); - const json: ApiPreuploadResponse = await res.json(); + const res = await (params.fetch ?? fetch)( + `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/preupload/${encodeURIComponent( + params.branch ?? "main" + )}` + (params.isPullRequest ? "?create_pr=1" : ""), + { + method: "POST", + headers: { + ...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }), + "Content-Type": "application/json", + }, + body: JSON.stringify(payload), + signal: abortSignal, + } + ); + + if (!res.ok) { + throw await createApiError(res); + } - for (const file of json.files) { - if (file.uploadMode === "lfs") { - lfsShas.set(file.path, null); + const json: ApiPreuploadResponse = await res.json(); + + for (const file of json.files) { + if (file.uploadMode === "lfs") { + lfsShas.set(file.path, null); + } } } - } - yield { event: "phase", phase: "uploadingLargeFiles" }; - - for (const operations of chunk( - allOperations.filter(isFileOperation).filter((op) => lfsShas.has(op.path)), - 100 - )) { - const shas = yield* eventToGenerator< - { event: "fileProgress"; type: "hashing"; path: string; progress: number }, - string[] - >((yieldCallback, returnCallback, rejectCallack) => { - return promisesQueue( - operations.map((op) => async () => { - const iterator = sha256(op.content, { useWebWorker: params.useWebWorkers, abortSignal: params.abortSignal }); - let res: IteratorResult; - do { - res = await iterator.next(); - if (!res.done) { - yieldCallback({ event: "fileProgress", path: op.path, progress: res.value, type: "hashing" }); - } - } while (!res.done); - const sha = res.value; - lfsShas.set(op.path, res.value); - return sha; - }), - CONCURRENT_SHAS - ).then(returnCallback, rejectCallack); - }); - - params.abortSignal?.throwIfAborted(); - - const payload: ApiLfsBatchRequest = { - operation: "upload", - // multipart is a custom protocol for HF - transfers: ["basic", "multipart"], - hash_algo: "sha_256", - ref: { - name: params.branch ?? "main", - }, - objects: operations.map((op, i) => ({ - oid: shas[i], - size: op.content.size, - })), - }; - - const res = await (params.fetch ?? fetch)( - `${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : repoId.type + "s/"}${ - repoId.name - }.git/info/lfs/objects/batch`, - { - method: "POST", - headers: { - ...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }), - Accept: "application/vnd.git-lfs+json", - "Content-Type": "application/vnd.git-lfs+json", + yield { event: "phase", phase: "uploadingLargeFiles" }; + + for (const operations of chunk( + allOperations.filter(isFileOperation).filter((op) => lfsShas.has(op.path)), + 100 + )) { + const shas = yield* eventToGenerator< + { event: "fileProgress"; type: "hashing"; path: string; progress: number }, + string[] + >((yieldCallback, returnCallback, rejectCallack) => { + return promisesQueue( + operations.map((op) => async () => { + const iterator = sha256(op.content, { useWebWorker: params.useWebWorkers, abortSignal: abortSignal }); + let res: IteratorResult; + do { + res = await iterator.next(); + if (!res.done) { + yieldCallback({ event: "fileProgress", path: op.path, progress: res.value, type: "hashing" }); + } + } while (!res.done); + const sha = res.value; + lfsShas.set(op.path, res.value); + return sha; + }), + CONCURRENT_SHAS + ).then(returnCallback, rejectCallack); + }); + + abortSignal?.throwIfAborted(); + + const payload: ApiLfsBatchRequest = { + operation: "upload", + // multipart is a custom protocol for HF + transfers: ["basic", "multipart"], + hash_algo: "sha_256", + ref: { + name: params.branch ?? "main", }, - body: JSON.stringify(payload), - signal: params.abortSignal, - } - ); + objects: operations.map((op, i) => ({ + oid: shas[i], + size: op.content.size, + })), + }; - if (!res.ok) { - throw await createApiError(res); - } + const res = await (params.fetch ?? fetch)( + `${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : repoId.type + "s/"}${ + repoId.name + }.git/info/lfs/objects/batch`, + { + method: "POST", + headers: { + ...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }), + Accept: "application/vnd.git-lfs+json", + "Content-Type": "application/vnd.git-lfs+json", + }, + body: JSON.stringify(payload), + signal: abortSignal, + } + ); - const json: ApiLfsBatchResponse = await res.json(); - const batchRequestId = res.headers.get("X-Request-Id") || undefined; + if (!res.ok) { + throw await createApiError(res); + } - const shaToOperation = new Map(operations.map((op, i) => [shas[i], op])); + const json: ApiLfsBatchResponse = await res.json(); + const batchRequestId = res.headers.get("X-Request-Id") || undefined; - yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) => { - return promisesQueueStreaming( - json.objects.map((obj) => async () => { - const op = shaToOperation.get(obj.oid); + const shaToOperation = new Map(operations.map((op, i) => [shas[i], op])); - if (!op) { - throw new InvalidApiResponseFormatError("Unrequested object ID in response"); - } + yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) => { + return promisesQueueStreaming( + json.objects.map((obj) => async () => { + const op = shaToOperation.get(obj.oid); - params.abortSignal?.throwIfAborted(); + if (!op) { + throw new InvalidApiResponseFormatError("Unrequested object ID in response"); + } - if (obj.error) { - const errorMessage = `Error while doing LFS batch call for ${operations[shas.indexOf(obj.oid)].path}: ${ - obj.error.message - }${batchRequestId ? ` - Request ID: ${batchRequestId}` : ""}`; - throw new HubApiError(res.url, obj.error.code, batchRequestId, errorMessage); - } - if (!obj.actions?.upload) { - // Already uploaded + abortSignal?.throwIfAborted(); + + if (obj.error) { + const errorMessage = `Error while doing LFS batch call for ${operations[shas.indexOf(obj.oid)].path}: ${ + obj.error.message + }${batchRequestId ? ` - Request ID: ${batchRequestId}` : ""}`; + throw new HubApiError(res.url, obj.error.code, batchRequestId, errorMessage); + } + if (!obj.actions?.upload) { + // Already uploaded + yieldCallback({ + event: "fileProgress", + path: op.path, + progress: 1, + type: "uploading", + }); + return; + } yieldCallback({ event: "fileProgress", path: op.path, - progress: 1, + progress: 0, type: "uploading", }); - return; - } - yieldCallback({ - event: "fileProgress", - path: op.path, - progress: 0, - type: "uploading", - }); - const content = op.content; - const header = obj.actions.upload.header; - if (header?.chunk_size) { - const chunkSize = parseInt(header.chunk_size); - - // multipart upload - // parts are in upload.header['00001'] to upload.header['99999'] - - const completionUrl = obj.actions.upload.href; - const parts = Object.keys(header).filter((key) => /^[0-9]+$/.test(key)); - - if (parts.length !== Math.ceil(content.size / chunkSize)) { - throw new Error("Invalid server response to upload large LFS file, wrong number of parts"); - } - - const completeReq: ApiLfsCompleteMultipartRequest = { - oid: obj.oid, - parts: parts.map((part) => ({ - partNumber: +part, - etag: "", - })), - }; - - // Defined here so that it's not redefined at each iteration (and the caller can tell it's for the same file) - const progressCallback = (progress: number) => - yieldCallback({ event: "fileProgress", path: op.path, progress, type: "uploading" }); - - await promisesQueueStreaming( - parts.map((part) => async () => { - params.abortSignal?.throwIfAborted(); - - const index = parseInt(part) - 1; - const slice = content.slice(index * chunkSize, (index + 1) * chunkSize); - - const res = await (params.fetch ?? fetch)(header[part], { - method: "PUT", - /** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */ - body: slice instanceof WebBlob && isFrontend ? await slice.arrayBuffer() : slice, - signal: params.abortSignal, - ...({ - progressHint: { - path: op.path, - part: index, - numParts: parts.length, - progressCallback, - }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any), - }); - - if (!res.ok) { - throw await createApiError(res, { - requestId: batchRequestId, - message: `Error while uploading part ${part} of ${ - operations[shas.indexOf(obj.oid)].path - } to LFS storage`, + const content = op.content; + const header = obj.actions.upload.header; + if (header?.chunk_size) { + const chunkSize = parseInt(header.chunk_size); + + // multipart upload + // parts are in upload.header['00001'] to upload.header['99999'] + + const completionUrl = obj.actions.upload.href; + const parts = Object.keys(header).filter((key) => /^[0-9]+$/.test(key)); + + if (parts.length !== Math.ceil(content.size / chunkSize)) { + throw new Error("Invalid server response to upload large LFS file, wrong number of parts"); + } + + const completeReq: ApiLfsCompleteMultipartRequest = { + oid: obj.oid, + parts: parts.map((part) => ({ + partNumber: +part, + etag: "", + })), + }; + + // Defined here so that it's not redefined at each iteration (and the caller can tell it's for the same file) + const progressCallback = (progress: number) => + yieldCallback({ event: "fileProgress", path: op.path, progress, type: "uploading" }); + + await promisesQueueStreaming( + parts.map((part) => async () => { + abortSignal?.throwIfAborted(); + + const index = parseInt(part) - 1; + const slice = content.slice(index * chunkSize, (index + 1) * chunkSize); + + const res = await (params.fetch ?? fetch)(header[part], { + method: "PUT", + /** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */ + body: slice instanceof WebBlob && isFrontend ? await slice.arrayBuffer() : slice, + signal: abortSignal, + ...({ + progressHint: { + path: op.path, + part: index, + numParts: parts.length, + progressCallback, + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any), }); - } - const eTag = res.headers.get("ETag"); + if (!res.ok) { + throw await createApiError(res, { + requestId: batchRequestId, + message: `Error while uploading part ${part} of ${ + operations[shas.indexOf(obj.oid)].path + } to LFS storage`, + }); + } - if (!eTag) { - throw new Error("Cannot get ETag of part during multipart upload"); - } + const eTag = res.headers.get("ETag"); - completeReq.parts[Number(part) - 1].etag = eTag; - }), - MULTIPART_PARALLEL_UPLOAD - ); + if (!eTag) { + throw new Error("Cannot get ETag of part during multipart upload"); + } - params.abortSignal?.throwIfAborted(); + completeReq.parts[Number(part) - 1].etag = eTag; + }), + MULTIPART_PARALLEL_UPLOAD + ); - const res = await (params.fetch ?? fetch)(completionUrl, { - method: "POST", - body: JSON.stringify(completeReq), - headers: { - Accept: "application/vnd.git-lfs+json", - "Content-Type": "application/vnd.git-lfs+json", - }, - signal: params.abortSignal, - }); + abortSignal?.throwIfAborted(); - if (!res.ok) { - throw await createApiError(res, { - requestId: batchRequestId, - message: `Error completing multipart upload of ${ - operations[shas.indexOf(obj.oid)].path - } to LFS storage`, + const res = await (params.fetch ?? fetch)(completionUrl, { + method: "POST", + body: JSON.stringify(completeReq), + headers: { + Accept: "application/vnd.git-lfs+json", + "Content-Type": "application/vnd.git-lfs+json", + }, + signal: abortSignal, }); - } - yieldCallback({ - event: "fileProgress", - path: op.path, - progress: 1, - type: "uploading", - }); - } else { - const res = await (params.fetch ?? fetch)(obj.actions.upload.href, { - method: "PUT", - headers: { - ...(batchRequestId ? { "X-Request-Id": batchRequestId } : undefined), - }, - /** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */ - body: content instanceof WebBlob && isFrontend ? await content.arrayBuffer() : content, - signal: params.abortSignal, - ...({ - progressHint: { - path: op.path, - progressCallback: (progress: number) => - yieldCallback({ - event: "fileProgress", - path: op.path, - progress, - type: "uploading", - }), - }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any), - }); + if (!res.ok) { + throw await createApiError(res, { + requestId: batchRequestId, + message: `Error completing multipart upload of ${ + operations[shas.indexOf(obj.oid)].path + } to LFS storage`, + }); + } - if (!res.ok) { - throw await createApiError(res, { - requestId: batchRequestId, - message: `Error while uploading ${operations[shas.indexOf(obj.oid)].path} to LFS storage`, + yieldCallback({ + event: "fileProgress", + path: op.path, + progress: 1, + type: "uploading", + }); + } else { + const res = await (params.fetch ?? fetch)(obj.actions.upload.href, { + method: "PUT", + headers: { + ...(batchRequestId ? { "X-Request-Id": batchRequestId } : undefined), + }, + /** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */ + body: content instanceof WebBlob && isFrontend ? await content.arrayBuffer() : content, + signal: abortSignal, + ...({ + progressHint: { + path: op.path, + progressCallback: (progress: number) => + yieldCallback({ + event: "fileProgress", + path: op.path, + progress, + type: "uploading", + }), + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any), }); - } - - yieldCallback({ - event: "fileProgress", - path: op.path, - progress: 1, - type: "uploading", - }); - } - }), - CONCURRENT_LFS_UPLOADS - ).then(returnCallback, rejectCallback); - }); - } - params.abortSignal?.throwIfAborted(); + if (!res.ok) { + throw await createApiError(res, { + requestId: batchRequestId, + message: `Error while uploading ${operations[shas.indexOf(obj.oid)].path} to LFS storage`, + }); + } - yield { event: "phase", phase: "committing" }; + yieldCallback({ + event: "fileProgress", + path: op.path, + progress: 1, + type: "uploading", + }); + } + }), + CONCURRENT_LFS_UPLOADS + ).then(returnCallback, rejectCallback); + }); + } - return yield* eventToGenerator( - async (yieldCallback, returnCallback, rejectCallback) => - (params.fetch ?? fetch)( - `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commit/${encodeURIComponent( - params.branch ?? "main" - )}` + (params.isPullRequest ? "?create_pr=1" : ""), - { - method: "POST", - headers: { - ...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }), - "Content-Type": "application/x-ndjson", - }, - body: [ - { - key: "header", - value: { - summary: params.title, - description: params.description, - parentCommit: params.parentCommit, - } satisfies ApiCommitHeader, + abortSignal?.throwIfAborted(); + + yield { event: "phase", phase: "committing" }; + + return yield* eventToGenerator( + async (yieldCallback, returnCallback, rejectCallback) => + (params.fetch ?? fetch)( + `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commit/${encodeURIComponent( + params.branch ?? "main" + )}` + (params.isPullRequest ? "?create_pr=1" : ""), + { + method: "POST", + headers: { + ...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }), + "Content-Type": "application/x-ndjson", }, - ...((await Promise.all( - allOperations.map((operation) => { - if (isFileOperation(operation)) { - const sha = lfsShas.get(operation.path); - if (sha) { - return { - key: "lfsFile", - value: { - path: operation.path, - algo: "sha256", - size: operation.content.size, - oid: sha, - } satisfies ApiCommitLfsFile, - }; + body: [ + { + key: "header", + value: { + summary: params.title, + description: params.description, + parentCommit: params.parentCommit, + } satisfies ApiCommitHeader, + }, + ...((await Promise.all( + allOperations.map((operation) => { + if (isFileOperation(operation)) { + const sha = lfsShas.get(operation.path); + if (sha) { + return { + key: "lfsFile", + value: { + path: operation.path, + algo: "sha256", + size: operation.content.size, + oid: sha, + } satisfies ApiCommitLfsFile, + }; + } } - } - - return convertOperationToNdJson(operation); - }) - )) satisfies ApiCommitOperation[]), - ] - .map((x) => JSON.stringify(x)) - .join("\n"), - signal: params.abortSignal, - ...({ - progressHint: { - progressCallback: (progress: number) => { - // For now, we display equal progress for all files - // We could compute the progress based on the size of `convertOperationToNdJson` for each of the files instead - for (const op of allOperations) { - if (isFileOperation(op) && !lfsShas.has(op.path)) { - yieldCallback({ - event: "fileProgress", - path: op.path, - progress, - type: "uploading", - }); + + return convertOperationToNdJson(operation); + }) + )) satisfies ApiCommitOperation[]), + ] + .map((x) => JSON.stringify(x)) + .join("\n"), + signal: abortSignal, + ...({ + progressHint: { + progressCallback: (progress: number) => { + // For now, we display equal progress for all files + // We could compute the progress based on the size of `convertOperationToNdJson` for each of the files instead + for (const op of allOperations) { + if (isFileOperation(op) && !lfsShas.has(op.path)) { + yieldCallback({ + event: "fileProgress", + path: op.path, + progress, + type: "uploading", + }); + } } - } + }, }, - }, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } as any), - } - ) - .then(async (res) => { - if (!res.ok) { - throw await createApiError(res); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any), } + ) + .then(async (res) => { + if (!res.ok) { + throw await createApiError(res); + } - const json = await res.json(); + const json = await res.json(); - returnCallback({ - pullRequestUrl: json.pullRequestUrl, - commit: { - oid: json.commitOid, - url: json.commitUrl, - }, - hookOutput: json.hookOutput, - }); - }) - .catch(rejectCallback) - ); + returnCallback({ + pullRequestUrl: json.pullRequestUrl, + commit: { + oid: json.commitOid, + url: json.commitUrl, + }, + hookOutput: json.hookOutput, + }); + }) + .catch(rejectCallback) + ); + } catch (err) { + // For parallel requests, cancel them all if one fails + abortController.abort(); + throw err; + } } export async function commit(params: CommitParams): Promise {