Conversation

@bobbyliujb

Summary:

Context

In our current KVZCH checkpoint (cp) loading flow, we hold the weight_id, weight, and optimizer tensors for the entire checkpoint-loading lifecycle; only once all of these tensors have been downloaded do we explicitly call "apply_state_dict" to write them chunk by chunk to the backend, ensuring id->weight and id->opt are mapped correctly. The problem is that with a large number of weights we run short of memory, since we must hold all three tensors at once (a double-memory issue). To solve this, we save the whole row (metaheader + weight + opt) as a single "weight" tensor during checkpoint saving; when downloading the checkpoint, we can then extract the id from the metaheader and write the weight + opt portion directly to the backend by id. When loading the checkpoint for the optimizer, we return a no-op KVTensor, so the optimizer states do not need to be written to the backend again.
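
For concreteness, a minimal Python sketch of this restore path under an assumed row layout (an 8-byte row id at the start of the metaheader, fp32 weights, and a hypothetical `backend.set_row` call; the real FBGEMM metaheader format and backend API differ in detail):

```python
import torch

# Assumed layout constants for illustration only; the actual FBGEMM
# metaheader format is defined by the backend and may differ.
METAHEADER_BYTES = 16  # assumed: 8-byte row id + 8 bytes of other metadata
ID_BYTES = 8

def split_packed_row(row: torch.Tensor, weight_dim: int):
    """Split one packed checkpoint row of [metaheader | weight | opt state].

    `row` is a 1-D uint8 view of the row as stored in the checkpoint's
    "weight" tensor when whole-row return is enabled.
    """
    row_id = int(row[:ID_BYTES].view(torch.int64)[0])     # id from the header
    body = row[METAHEADER_BYTES:]                         # weight + opt state
    weight = body[: weight_dim * 4].view(torch.float32)   # assuming fp32
    opt_state = body[weight_dim * 4:].view(torch.float32)
    return row_id, weight, opt_state

def restore_rows(packed_rows: torch.Tensor, weight_dim: int, backend) -> None:
    """Stream rows from a downloaded shard straight into the backend.

    Only the weight + opt slice is written (the metaheader is skipped), and
    no separate weight_id / weight / optimizer tensors are held in memory,
    which is the point of the whole-row format.
    """
    for row in packed_rows:
        row_id, weight, opt_state = split_packed_row(row, weight_dim)
        backend.set_row(row_id, torch.cat([weight, opt_state]))  # hypothetical API
```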

This diff

  • added a `backend_return_whole_row` flag in the KVZCH params, with validation to ensure it can only be True when optimizer offloading is used
  • added a `read_only_` flag in KVTensorWrapper for checkpoint calls; when read-only is True, all write operations to that KVT become no-ops (see the sketch after this list)
  • added metadata recalculation for the optimizer state dict: since we now return read-only KVTs for the optimizer state dict, the model store needs to correct the global metadata before creating the save plan for KVZCH optimizer tensors
  • updated the DRAM backend and memory pool so they can return metaheader + weight + optimizer_state together and set them back to the backend (using pointer offsets to skip the metaheader when writing weight + opt)
  • by default, optimizer offloading and whole-row return are False on trunk, so existing KVZCH runs should not be affected
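
A toy, self-contained sketch of what the read-only guard amounts to (the class, method names, and dict-backed store below are illustrative stand-ins, not the actual KVTensorWrapper API):

```python
class DictBackend:
    """Toy in-memory stand-in for the key-value embedding backend."""

    def __init__(self):
        self.rows = {}

    def get_row(self, row_id):
        return self.rows.get(row_id)

    def set_row(self, row_id, value):
        self.rows[row_id] = value


class ReadOnlyKVTensorSketch:
    """Illustrative stand-in for a KVTensorWrapper carrying a read-only flag.

    With read_only=True the write path is a no-op: when the optimizer state
    dict is loaded from a checkpoint, id->opt was already restored from the
    packed weight rows, so nothing should be written to the backend again.
    """

    def __init__(self, backend, read_only: bool = False):
        self.backend = backend
        self.read_only = read_only

    def get(self, row_id):
        # Reads stay available, e.g. for checkpoint saving / save planning.
        return self.backend.get_row(row_id)

    def set(self, row_id, value):
        if self.read_only:
            return  # no-op write while loading optimizer KVTs from a checkpoint
        self.backend.set_row(row_id, value)
```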

Differential Revision: D77604158

Privacy Context Container: L1138451

@netlify

netlify bot commented Jul 1, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

| Name | Link |
|------|------|
| 🔨 Latest commit | 2940240 |
| 🔍 Latest deploy log | https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/686c1cf56a0b7c0008ef7eb1 |
| 😎 Deploy Preview | https://deploy-preview-4429--pytorch-fbgemm-docs.netlify.app |

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D77604158

bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 1, 2025
…kend for checkpoint saving/loading (meta-pytorch#3148)

bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 1, 2025
…kend for checkpoint saving/loading (pytorch#4429)

@bobbyliujb bobbyliujb force-pushed the export-D77604158 branch 2 times, most recently from 8c7a58c to 0c5b161 Compare July 2, 2025 17:37
bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 2, 2025
…kend for checkpoint saving/loading (pytorch#4429)

Summary:

X-link: facebookresearch/FBGEMM#1495

X-link: meta-pytorch/torchrec#3148

# This diff only contains backend change
* updated dram backend and mem pool, so it can return the metaheader + weight + optimizer_state together, as well as set them back to backend (use pointers to skip metaheader part when write weight+opt to backend)

Reviewed By: emlin

Differential Revision: D77604158
bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4429)

@facebook-github-bot
Contributor

This pull request has been merged in f2e75f5.
