Labels: PyTorch (not traced), bug (Unexpected behaviour that should be corrected)
Description
🐞 Describing the bug
I ran into an issue while trying to reproduce the NLP model conversion tutorial from the documentation. The tutorial does not work with the latest versions of the tools. The error points at the inputs, but the input has exactly the shape the model expects.
Stack Trace
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 1
----> 1 mlmodel = ct.convert(
2 scripted_model,
3 # Range for the sequence dimension to be between [1, 64]
4 inputs=[ct.TensorType(name="context", shape=(ct.RangeDim(1, 64),), dtype=np.int32)],
5 )
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/_converters_entry.py:444, in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)
441 if specification_version is None:
442 specification_version = _set_default_specification_version(exact_target)
--> 444 mlmodel = mil_convert(
445 model,
446 convert_from=exact_source,
447 convert_to=exact_target,
448 inputs=inputs,
449 outputs=outputs_as_tensor_or_image_types, # None or list[ct.ImageType/ct.TensorType]
450 classifier_config=classifier_config,
451 transforms=tuple(transforms),
452 skip_model_load=skip_model_load,
453 compute_units=compute_units,
454 package_dir=package_dir,
455 debug=debug,
456 specification_version=specification_version,
457 )
459 if exact_target == 'milinternal':
460 return mlmodel # Returns the MIL program
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:190, in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
151 @_profile
152 def mil_convert(
153 model,
(...)
157 **kwargs
158 ):
159 """
160 Convert model from a specified frontend `convert_from` to a specified
161 converter backend `convert_to`.
(...)
188 See `coremltools.converters.convert`
189 """
--> 190 return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:217, in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
214 # To make sure everyone can read and write to this directory (on par with os.mkdir())
215 _os.chmod(weights_dir, _stat.S_IRWXU | _stat.S_IRWXG | _stat.S_IRWXO)
--> 217 proto, mil_program = mil_convert_to_proto(
218 model,
219 convert_from,
220 convert_to,
221 registry,
222 **kwargs
223 )
225 _reset_conversion_state()
227 if convert_to == 'milinternal':
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:282, in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
279 kwargs.setdefault("convert_to", convert_to)
280 frontend_converter = frontend_converter_type()
--> 282 prog = frontend_converter(model, **kwargs)
284 if convert_to.lower() != "neuralnetwork":
285 passes = kwargs.get("transforms", list())
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:112, in TorchFrontend.__call__(self, *args, **kwargs)
109 def __call__(self, *args, **kwargs):
110 from .frontend.torch import load
--> 112 return load(*args, **kwargs)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py:57, in load(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)
55 inputs = _convert_to_torch_inputtype(inputs)
56 converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols, specification_version)
---> 57 return _perform_torch_convert(converter, debug)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py:96, in _perform_torch_convert(converter, debug)
94 def _perform_torch_convert(converter, debug):
95 try:
---> 96 prog = converter.convert()
97 except RuntimeError as e:
98 if debug and "convert function" in str(e):
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py:270, in TorchConverter.convert(self)
267 self.convert_const()
269 # Add the rest of the operations
--> 270 convert_nodes(self.context, self.graph)
272 graph_outputs = [self.context[name] for name in self.graph.outputs]
274 # An output can be None when it's a None constant, which happens
275 # in Fairseq MT.
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:103, in convert_nodes(context, graph)
99 if add_op is None:
100 raise RuntimeError(
101 "PyTorch convert function for op '{}' not implemented.".format(node.kind)
102 )
--> 103 add_op(context, node)
105 # We've generated all the outputs the graph needs, terminate conversion.
106 if _all_outputs_present(context, graph):
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:3095, in loop(context, node)
3087 # Must return tuple with same length and types as @loop_vars.
3088 return tuple(
3089 [
3090 iter_var,
3091 ]
3092 + res
3093 )
-> 3095 loop = mb.while_loop(
3096 _cond=_loop_cond, _body=_loop_body, loop_vars=loop_vars, name=name
3097 )
3099 # Make sure the loop returned the expected number of outputs. Note that the
3100 # first two loop outputs are the iteration count and condition.
3101 assert len(loop) - 2 == len(node.outputs)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/ops/registry.py:178, in SSAOpRegistry.register_op.<locals>.class_wrapper.<locals>.add_op(cls, **kwargs)
175 else:
176 op_cls_to_add = op_reg[op_type]
--> 178 return cls._add_op(op_cls_to_add, **kwargs)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/builder.py:181, in Builder._add_op(cls, op_cls, **kwargs)
177 new_op.set_inputs(type_inference=False,
178 **missing_optional_vars)
180 curr_block()._insert_op_before(new_op, before_op=before_op)
--> 181 new_op.build_nested_blocks()
182 new_op.type_value_inference()
183 if len(new_op.outputs) == 1:
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py:440, in while_loop.build_nested_blocks(self)
437 v._sym_val = v._sym_val
438 v.consuming_blocks = list()
--> 440 cond_block, body_block, exit_vars = self._build_block(block_inputs)
442 # Verify exit_vars has the same types as loop_vars
443 block_input_type_change = False
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py:374, in while_loop._build_block(self, block_inputs)
371 with Block(block_inputs=block_inputs, outer_op=self,
372 name=block_name) as body_block:
373 body_func = self._body.val
--> 374 exit_vars = body_func(*body_block.inputs)
375 exit_vars = list(exit_vars) if isinstance(exit_vars, (list, tuple)) \
376 else [exit_vars]
377 body_block.set_outputs(exit_vars)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:3065, in loop.<locals>._loop_body(*loop_vars)
3063 iter_var = loop_vars[0]
3064 inputs = (iter_var,) + loop_vars[2:]
-> 3065 res = convert_block(context, block, inputs)
3067 for input_var, output_var in zip(loop_vars[2:], res[1:]):
3068 if not _shapes_are_equivalent(input_var.shape, output_var.shape):
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:129, in convert_block(context, block, inputs)
126 context.push((block.inputs, inputs))
128 # Add the block ops.
--> 129 convert_nodes(context, block)
131 # Collect the block outputs.
132 outputs = [context[outp] for outp in block.outputs]
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:103, in convert_nodes(context, graph)
99 if add_op is None:
100 raise RuntimeError(
101 "PyTorch convert function for op '{}' not implemented.".format(node.kind)
102 )
--> 103 add_op(context, node)
105 # We've generated all the outputs the graph needs, terminate conversion.
106 if _all_outputs_present(context, graph):
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:3548, in full(context, node)
3546 size = inputs[0]
3547 val = inputs[1].val
-> 3548 result = _make_fill_op(size, val, node.name)
3549 context.add(result)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:3539, in _make_fill_op(size, val, name)
3537 if isinstance(size, list):
3538 size = mb.concat(values=size, axis=0)
-> 3539 fill = mb.fill(shape=size, value=val, name=name)
3540 return fill
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/ops/registry.py:178, in SSAOpRegistry.register_op.<locals>.class_wrapper.<locals>.add_op(cls, **kwargs)
175 else:
176 op_cls_to_add = op_reg[op_type]
--> 178 return cls._add_op(op_cls_to_add, **kwargs)
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/builder.py:166, in Builder._add_op(cls, op_cls, **kwargs)
161 kwargs = {k: v if not isinstance(v, (list, tuple)) else v[:] for k, v in kwargs.items() if v is not None}
162 kwargs.update(cls._create_vars(
163 input_spec=op_cls.input_spec,
164 op_name=kwargs["name"], before_op=before_op,
165 candidate_kv=kwargs))
--> 166 new_op = op_cls(**kwargs)
168 # Initialize optional input Vars if it wasn't in kwargs
169 default_inputs = new_op.default_inputs()
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/operation.py:182, in Operation.__init__(self, **kwargs)
179 # Set inputs from kwargs
180 input_kv = {k: v for k, v in kwargs.items()
181 if k in self._input_types and v is not None}
--> 182 self._validate_and_set_inputs(input_kv)
183 self._ensure_required_inputs()
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/operation.py:479, in Operation._validate_and_set_inputs(self, input_kvs, no_check_var_types)
476 raise ValueError(msg.format(v_new.sym_type, v_old.sym_type))
477 v_old.remove_child_op(op, no_check_var_types)
--> 479 self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
481 for name, var in input_kvs.items():
482 # TODO: remove InternalVar check
483 # if not isinstance(var, InternalVar):
484
485 # Remove this operation itself from existing input
486 # Var's child_ops
487 existing_input_var = self._input_vars[name]
File ~/opt/miniconda3/envs/coremltools-env/lib/python3.8/site-packages/coremltools/converters/mil/mil/input_type.py:150, in InputSpec.validate_inputs(self, op_name, op_type, candidate_kvs)
146 if not isinstance(var, InternalVar) and \
147 not input_type.is_compatible(var):
148 msg = msg_prefix + "Input {}=\"{}\" expects " +\
149 "{} but got {}"
--> 150 raise ValueError(msg.format(name, var.name, input_type.type_str,
151 var.sym_type.__type_info__()))
ValueError: Op "137" (op_type: fill) Input shape="136" expects tensor or scalar of dtype from type domain ['int32'] but got tensor[0,fp32]
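The check that fires is on MIL's fill op: its shape input must come from the int32 type domain, but during conversion of the scripted while loop it receives an fp32 tensor. The same validation error can be reproduced directly with the MIL builder (a minimal sketch to illustrate the dtype check, not taken from the tutorial):

import numpy as np
from coremltools.converters.mil import Builder as mb

@mb.program(input_specs=[])
def prog():
    # fill requires its shape input to be int32; an fp32 shape
    # triggers the same ValueError as in the trace above.
    return mb.fill(shape=np.array([2.0], dtype=np.float32), value=0.0)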
To Reproduce
import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import coremltools as ct

token_predictor = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

class FinishMySentence(torch.nn.Module):
    def __init__(self, model=None, eos=198):
        super(FinishMySentence, self).__init__()
        self.eos = torch.tensor([eos])
        self.next_token_predictor = model
        self.default_token = torch.tensor([0])

    def forward(self, x):
        sentence = x
        token = self.default_token
        while token != self.eos:
            predictions, _ = self.next_token_predictor(sentence)
            token = torch.argmax(predictions[-1, :], dim=0, keepdim=True)
            sentence = torch.cat((sentence, token), 0)
        return sentence

random_tokens = torch.randint(10000, (5,))
traced_token_predictor = torch.jit.trace(token_predictor, random_tokens)

model = FinishMySentence(model=traced_token_predictor)
scripted_model = torch.jit.script(model)

mlmodel = ct.convert(
    scripted_model,
    # Range for the sequence dimension to be between [1, 64]
    inputs=[ct.TensorType(name="context", shape=(ct.RangeDim(1, 64),), dtype=np.int32)],
)

System environment:
torch: 1.13.0
np: 1.23.5
transformers: 4.25.1
coremltools: 6.1
macOS Ventura 13.0
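Per the trace, the failing fill is converted from an aten::full node inside the loop body. As a possible debugging step (a sketch, not part of the tutorial; inlined_graph and findAllNodes are standard TorchScript graph APIs), the scripted graph can be searched for those nodes before conversion:

# Print every aten::full node in the scripted graph (inlined_graph also
# inlines the traced submodule) to see which call feeds the shape MIL rejects.
for node in scripted_model.inlined_graph.findAllNodes("aten::full"):
    print(node)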