feat: update openclip loader #782
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main     #782      +/-   ##
==========================================
- Coverage   86.28%   85.72%   -0.56%
==========================================
  Files          21       21
  Lines        1108     1121      +13
==========================================
+ Hits          956      961       +5
- Misses        152      160       +8
```
```python
if pretrained.lower() == 'openai':
    try:
        # loading JIT archive
        model = torch.jit.load(
            model_path, map_location=device if jit else "cpu"
        ).eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")
    if not jit:
        try:
            model = build_model_from_openai_state_dict(
                state_dict or model.state_dict()
            ).to(device)
        except KeyError:
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(sd).to(device)
        if str(device) == "cpu":
            model.float()
    else:
        # patch the device names
        device_holder = torch.jit.trace(
            lambda: torch.ones([]).to(torch.device(device)),
            example_inputs=[],
        )
        device_node = [
            n
            for n in device_holder.graph.findAllNodes("prim::Constant")
            if "Device" in repr(n)
        ][-1]

        def patch_device(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("prim::Constant"):
                    if "value" in node.attributeNames() and str(
                        node["value"]
                    ).startswith("cuda"):
                        node.copyAttributes(device_node)

        model.apply(patch_device)
        patch_device(model.encode_image)
        patch_device(model.encode_text)

        # patch dtype to float32 on CPU
        if device == "cpu":
            float_holder = torch.jit.trace(
                lambda: torch.ones([]).float(), example_inputs=[]
            )
            float_input = list(
                float_holder.graph.findNode("aten::to").inputs()
            )[1]
            float_node = float_input.node()

            def patch_float(module):
                try:
                    graphs = [module.graph] if hasattr(module, "graph") else []
                except RuntimeError:
                    graphs = []

                if hasattr(module, "forward1"):
                    graphs.append(module.forward1.graph)

                for graph in graphs:
                    for node in graph.findAllNodes("aten::to"):
                        inputs = list(node.inputs())
                        # dtype can be the second or third argument to aten::to()
                        for i in [1, 2]:
                            if inputs[i].node()["value"] == 5:
                                inputs[i].node().copyAttributes(float_node)

            model.apply(patch_float)
            patch_float(model.encode_image)
            patch_float(model.encode_text)
            model.float()

        # ensure image_size attr available at consistent location for both jit and non-jit
        model.visual.image_size = model.input_resolution.item()
    if precision == "fp32":
        model = model.float()
```
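The `KeyError` fallback above handles checkpoints that nest the weights under a `state_dict` key and prefix every key with a 7-character string (hence `k[7:]`; `"module."` is a common such prefix, though the exact prefix here is an inference). A minimal sketch of that key normalization on a toy checkpoint:

```python
# Toy checkpoint shaped like a DataParallel/Lightning-style export: weights
# live under "state_dict" and keys carry a 7-char "module." prefix (assumed).
ckpt = {
    "state_dict": {
        "module.visual.conv1.weight": "w0",
        "module.logit_scale": "w1",
    }
}

# k[7:] drops the 7-character prefix so keys match the bare model's names.
sd = {k[7:]: v for k, v in ckpt["state_dict"].items()}
print(sorted(sd))  # ['logit_scale', 'visual.conv1.weight']
```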
We can simply use load_openai_model() here.
We can't do that: load_openai_model uses open_clip's own download method, which prevents us from downloading the model from our s3.
No, load_openai_model also accepts model_path as a parameter.
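If the loader does accept a checkpoint path, the custom s3 download can stay fully separate from the library's downloader. A hedged sketch of that name-vs-path dispatch, using a pure-Python stand-in rather than the real open_clip API (the `download` fallback and names here are illustrative assumptions):

```python
import os

def resolve_checkpoint(name_or_path, download=lambda n: f"/cache/{n}.pt"):
    """Stand-in for name-vs-path dispatch: an existing local file (e.g. one
    we fetched from our own s3 beforehand) is used directly, bypassing the
    library's built-in download; otherwise fall back to `download`."""
    if os.path.isfile(name_or_path):
        return name_or_path
    return download(name_or_path)
```

With this shape, the s3 fetch happens before the loader is ever called, and the loader only ever sees a local path.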
```python
model = CLIP(**model_cfg)

if pretrained:
```
In our case, pretrained cannot be None/empty.
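Since `pretrained` can never be None/empty in this flow, a cheap guard makes the invariant explicit instead of relying on the caller. A sketch (function name and message are illustrative, not the PR's actual code):

```python
def require_pretrained(pretrained):
    """Fail fast instead of silently skipping the weight-loading branch."""
    if not pretrained:
        raise ValueError("pretrained must be a non-empty tag or checkpoint path")
    return pretrained
```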
```python
if pretrained:
    if model_path:
        model.load_state_dict(load_state_dict(model_path))
    else:
```
Is the else branch necessary here?
```python
if precision == "fp16":
    convert_weights_to_fp16(model)
```
Change the logic: if device is cuda, then we use fp16.
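The suggested device-driven precision rule can be sketched as a tiny helper (names are illustrative, not the PR's actual code):

```python
def choose_precision(device):
    """fp16 on cuda to reduce VRAM use, fp32 elsewhere. str() also
    normalizes a torch.device object if one is passed instead of a str."""
    return "fp16" if str(device).startswith("cuda") else "fp32"
```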
```python
self._model_name = model_name
else:
```
we should make sure all of the models are uploaded to our s3 bucket.
I uploaded the pt models; they should be available in a few minutes.
Update: make that a few hours 😅 poor internet.
```diff
@@ -74,4 +74,4 @@ def encode_text(
     )

 def encode_image(self, pixel_values: torch.Tensor, **kwargs):
-    return self._model.encode_image(pixel_values)
+    return self._model.encode_image(pixel_values, **kwargs)
```
No, we are sure that encode_image only accepts pixel_values here, so kwargs should not be considered.
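Dropping the `**kwargs` forwarding makes that contract explicit: an unexpected keyword fails loudly at the call site instead of being passed along. A toy illustration (not the PR's actual class):

```python
def encode_image(pixel_values):
    """Strict signature: only pixel_values is accepted."""
    return ("encoded", pixel_values)

# An unsupported keyword now raises TypeError immediately.
try:
    encode_image([1, 2], normalize=True)
except TypeError as e:
    err = str(e)
```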
```python
model.load_state_dict(load_state_dict(model_path))
model.to(device=torch.device(device))

if device == 'cuda':
```
Suggested change:

```diff
-if device == 'cuda':
+if str(device) == 'cuda':
```
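The str() wrapper matters because `device` may be a torch.device object rather than a string, and comparing such an object directly to 'cuda' may not behave as expected depending on the PyTorch version. A torch-free stand-in of the pitfall (FakeDevice is a hypothetical stub, not torch.device itself):

```python
class FakeDevice:
    """Minimal stand-in for a device object: prints as its type string."""
    def __init__(self, type_):
        self.type = type_

    def __str__(self):
        return self.type

device = FakeDevice("cuda")
assert (device == "cuda") is False  # plain equality compares object vs str
assert str(device) == "cuda"        # normalizing via str() is robust
```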
Updated d361f72 to ab23235: to support an independent download process and make precision adapt to the device, solving the VRAM issue.
LGTM 👍