Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 64 additions & 57 deletions packages/hub/src/lib/commit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,13 @@ export type CommitParams = {
*/
fetch?: typeof fetch;
abortSignal?: AbortSignal;
// Credentials are optional due to custom fetch functions or cookie auth
/**
* @default true
*
* Use xet protocol: https://huggingface.co/blog/xet-on-the-hub to upload, rather than a basic S3 PUT
*/
useXet?: boolean;
// Credentials are optional due to custom fetch functions or cookie auth
} & Partial<CredentialsParams>;

export interface CommitOutput {
Expand Down Expand Up @@ -165,24 +170,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
const repoId = toRepoId(params.repo);
yield { event: "phase", phase: "preuploading" };

let useXet = params.useXet;
if (useXet) {
const info = await (params.fetch ?? fetch)(
`${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}?expand[]=xetEnabled`,
{
headers: {
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
},
}
);

if (!info.ok) {
throw await createApiError(info);
}

const data = await info.json();
useXet = !!data.xetEnabled;
}
let useXet = params.useXet ?? true;

const lfsShas = new Map<string, string | null>();

Expand All @@ -206,10 +194,6 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
const allOperations = (
await Promise.all(
params.operations.map(async (operation) => {
if (operation.operation === "edit" && !useXet) {
throw new Error("Edit operation is not supported when Xet is disabled");
}

if (operation.operation === "edit") {
// Convert EditFile operation to a file operation with SplicedBlob
const splicedBlob = SplicedBlob.create(
Expand Down Expand Up @@ -325,7 +309,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
const payload: ApiLfsBatchRequest = {
operation: "upload",
// multipart is a custom protocol for HF
transfers: ["basic", "multipart"],
transfers: ["basic", "multipart", ...(useXet ? ["xet" as const] : [])],
hash_algo: "sha_256",
...(!params.isPullRequest && {
ref: {
Expand Down Expand Up @@ -363,6 +347,12 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr

const shaToOperation = new Map(operations.map((op, i) => [shas[i], op]));

if (useXet && json.transfer !== "xet") {
useXet = false;
}
let xetRefreshWriteTokenUrl: string | undefined;
let xetSessionId: string | undefined;

if (useXet) {
// First get all the files that are already uploaded out of the way
for (const obj of json.objects) {
Expand All @@ -386,6 +376,17 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
progress: 1,
state: "uploading",
};
} else {
xetRefreshWriteTokenUrl = obj.actions.upload.href;
// Also, obj.actions.upload.header: {
// X-Xet-Cas-Url: string;
// X-Xet-Access-Token: string;
// X-Xet-Token-Expiration: string;
// X-Xet-Session-Id: string;
// }
const headers = new Headers(obj.actions.upload.header);
xetSessionId = headers.get("X-Xet-Session-Id") ?? undefined;
// todo: use other data, like x-xet-cas-url, ...
}
}
const source = (async function* () {
Expand All @@ -395,43 +396,49 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
continue;
}
abortSignal?.throwIfAborted();

yield { content: op.content, path: op.path, sha256: obj.oid };
}
})();
const sources = splitAsyncGenerator(source, 5);
yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) =>
Promise.all(
sources.map(async function (source) {
for await (const event of uploadShards(source, {
fetch: params.fetch,
accessToken,
hubUrl: params.hubUrl ?? HUB_URL,
repo: repoId,
// todo: maybe leave empty if PR?
rev: params.branch ?? "main",
isPullRequest: params.isPullRequest,
yieldCallback: (event) => yieldCallback({ ...event, state: "uploading" }),
})) {
if (event.event === "file") {
yieldCallback({
event: "fileProgress" as const,
path: event.path,
progress: 1,
state: "uploading" as const,
});
} else if (event.event === "fileProgress") {
yieldCallback({
event: "fileProgress" as const,
path: event.path,
progress: event.progress,
state: "uploading" as const,
});
if (xetRefreshWriteTokenUrl) {
const xetRefreshWriteTokenUrlFixed = xetRefreshWriteTokenUrl;
const sources = splitAsyncGenerator(source, 5);
yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) =>
Promise.all(
sources.map(async function (source) {
for await (const event of uploadShards(source, {
fetch: params.fetch,
accessToken,
hubUrl: params.hubUrl ?? HUB_URL,
repo: repoId,
xetRefreshWriteTokenUrl: xetRefreshWriteTokenUrlFixed,
xetSessionId,
// todo: maybe leave empty if PR?
rev: params.branch ?? "main",
isPullRequest: params.isPullRequest,
yieldCallback: (event) => yieldCallback({ ...event, state: "uploading" }),
})) {
if (event.event === "file") {
yieldCallback({
event: "fileProgress" as const,
path: event.path,
progress: 1,
state: "uploading" as const,
});
} else if (event.event === "fileProgress") {
yieldCallback({
event: "fileProgress" as const,
path: event.path,
progress: event.progress,
state: "uploading" as const,
});
}
}
}
})
).then(() => returnCallback(undefined), rejectCallback)
);
})
).then(() => returnCallback(undefined), rejectCallback)
);
} else {
// No LFS file to upload
}
} else {
yield* eventToGenerator<CommitProgressEvent, void>((yieldCallback, returnCallback, rejectCallback) => {
return promisesQueueStreaming(
Expand Down
4 changes: 2 additions & 2 deletions packages/hub/src/types/api/api-commit.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export interface ApiLfsBatchRequest {
/// github.com/git-lfs/git-lfs/blob/master/docs/api/batch.md
operation: "download" | "upload";
transfers?: string[];
transfers?: Array<ApiLfsResponseTransfer>;
/**
* Optional object describing the server ref that the objects belong to. Note: Added in v2.4.
*
Expand Down Expand Up @@ -29,7 +29,7 @@ export interface ApiLfsBatchResponse {
objects: ApiLfsResponseObject[];
}

export type ApiLfsResponseTransfer = "basic" | "multipart";
export type ApiLfsResponseTransfer = "basic" | "multipart" | "xet";

export interface ApiLfsCompleteMultipartRequest {
oid: string;
Expand Down
4 changes: 4 additions & 0 deletions packages/hub/src/utils/uploadShards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ export const SHARD_MAGIC_TAG = new Uint8Array([
interface UploadShardsParams {
accessToken: string | undefined;
hubUrl: string;
xetRefreshWriteTokenUrl: string;
xetSessionId: string | undefined;
fetch?: typeof fetch;
repo: RepoId;
rev: string;
Expand Down Expand Up @@ -365,6 +367,7 @@ async function uploadXorb(
body: xorb.xorb,
headers: {
Authorization: `Bearer ${token.accessToken}`,
...(params.xetSessionId ? { "X-Xet-Session-Id": params.xetSessionId } : {}),
},
...{
progressHint: {
Expand Down Expand Up @@ -394,6 +397,7 @@ async function uploadShard(shard: Uint8Array, params: UploadShardsParams) {
body: shard,
headers: {
Authorization: `Bearer ${token.accessToken}`,
...(params.xetSessionId ? { "X-Xet-Session-Id": params.xetSessionId } : {}),
},
});

Expand Down
8 changes: 5 additions & 3 deletions packages/hub/src/utils/xetWriteToken.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface XetWriteTokenParams {
repo: RepoId;
rev: string;
isPullRequest?: boolean;
xetRefreshWriteTokenUrl: string | undefined;
}

const JWT_SAFETY_PERIOD = 60_000;
Expand Down Expand Up @@ -47,9 +48,10 @@ export async function xetWriteToken(params: XetWriteTokenParams): Promise<{ acce

const promise = (async () => {
const resp = await (params.fetch ?? fetch)(
`${params.hubUrl}/api/${params.repo.type}s/${params.repo.name}/xet-write-token/${encodeURIComponent(
params.rev
)}` + (params.isPullRequest ? "?create_pr=1" : ""),
params.xetRefreshWriteTokenUrl ??
`${params.hubUrl}/api/${params.repo.type}s/${params.repo.name}/xet-write-token/${encodeURIComponent(
params.rev
)}` + (params.isPullRequest ? "?create_pr=1" : ""),
{
headers: params.accessToken
? {
Expand Down
Loading