-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding methods to create constant tensor. #25
Conversation
# Creates a constant Tensor that is added to the graph with a specified name. | ||
# Official documentation of {tf.constant}[https://www.tensorflow.org/versions/r0.9/api_docs/python/constant_op.html#constant]. | ||
# | ||
def constant(name, data, type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Arafatk shouldn't we try and follow as much as possible the same interface (order of arguments) as the Python wrapper? I.e. tf.constant(value, dtype=None, shape=None, name='Const')
https://www.tensorflow.org/versions/r0.9/api_docs/python/constant_op.html#constant
Nice, short and sweet implementation!
As far as I can tell, SciRuby currently prefers NMatrix over NArray big time (and NArray is in the strange development state, it seems). Does NMatrix have functionality you need? Or maybe it is not so complex to implement by hands, sticking with just Ruby Array? |
end | ||
result | ||
end | ||
result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First, you don't need result
variable at all: all branches just finish with calculation of result, so it will be returned in this case:
case attribute_type
when "type"
Tensorflow::AttrValue.new(type: value) # just put the value number here
when "tensor"
tensor_element_type = value.type_num
case value.type_num
when Tensorflow::TF_DOUBLE
Tensorflow::AttrValue.new(tensor: Tensorflow::TensorProto.new(dtype: value.type_num, tensor_shape: value.tensor_shape_proto, tensor_content: value.flatten.pack("d*")))
when Tensorflow::TF_INT32
Tensorflow::AttrValue.new(tensor: Tensorflow::TensorProto.new(dtype: value.type_num, tensor_shape: value.tensor_shape_proto, tensor_content: value.flatten.pack("l*")))
when Tensorflow::TF_INT64
Tensorflow::AttrValue.new(tensor: Tensorflow::TensorProto.new(dtype: value.type_num, tensor_shape: value.tensor_shape_proto, tensor_content: value.flatten.pack("q*")))
when Tensorflow::TF_COMPLEX128
tensor_narray = NArray.complex(value.flatten.length)
(0..value.flatten.length - 1).each do |i|
tensor_narray[i] = value.flatten[i]
end
Tensorflow::AttrValue.new(tensor: Tensorflow::TensorProto.new(dtype: value.type_num, tensor_shape: value.tensor_shape_proto, tensor_content: tensor_narray.to_s))
end
end
But it seems code could be simplified even more: there is exactly the same lines in all "tensor"
branch:
Tensorflow::AttrValue.new(tensor: Tensorflow::TensorProto.new(dtype: value.type_num, tensor_shape: value.tensor_shape_proto, tensor_content: "the only difference is here"))
So, what do you think about this:
case attribute_type
when "type"
Tensorflow::AttrValue.new(type: value) # just put the value number here
when "tensor"
tensor_element_type = value.type_num
content =
case value.type_num
when Tensorflow::TF_DOUBLE
value.flatten.pack("d*")
when Tensorflow::TF_INT32
value.flatten.pack("l*")
when Tensorflow::TF_INT64
value.flatten.pack("q*")
when Tensorflow::TF_COMPLEX128
tensor_narray = NArray.complex(value.flatten.length)
(0..value.flatten.length - 1).each do |i|
tensor_narray[i] = value.flatten[i]
end
tensor_narray.to_s
end
Tensorflow::AttrValue.new(
tensor: Tensorflow::TensorProto.new(
dtype: value.type_num,
tensor_shape: value.tensor_shape_proto,
tensor_content: content
)
)
end
Now we can see that content =
calculation could be easily extracted into separate method, making code to look somewhat like this:
case attribute_type
when "type"
Tensorflow::AttrValue.new(type: value) # just put the value number here
when "tensor"
Tensorflow::AttrValue.new(
tensor: Tensorflow::TensorProto.new(
dtype: value.type_num,
tensor_shape: value.tensor_shape_proto,
tensor_content: value_to_tensor_content(value)
)
)
end
And, finally, typically it is a good practice to have else
branch in case
(even if you are sure it is impossible to reach -- impossibility handler, raising an error, could save you a hours of debugging).
@zverok we had NMatrix in shortly, but removed it again (https://github.com/Arafatk/tensorflow.rb/pull/16) since it didn't install without manual intervention on macs, SciRuby/nmatrix#505 (comment). But actually, it looks like they've fixed this by now: SciRuby/nmatrix#530. @Arafatk do you wanna swap in NMatrix instead of NArray (provided specs still pass)? |
@chrhansen Thanks for the suggestion. I have used narray for the simple reason that it has to_s method for Complex arrays. We don't have that feature in Ruby Array and also its not available in NMAtrix. I have tried searching but I could not find any other popular gem that does this. |
Before we review this a few questions
If this works fine for you,then I will add support for a few data types that I had left before in set_type method of tensor(and also for constant).
Another important thing to be noted is that tf.variable will also be made using advantage of constant op.