Static types for the Keras API #2

Open
shadaj opened this issue Mar 11, 2019 · 4 comments

Comments

@shadaj
Owner

shadaj commented Mar 11, 2019

The recommended API for use with TensorFlow is now Keras, so we should have static type definitions for it.

@Avasil

Avasil commented Mar 17, 2020

Hi @shadaj, I'm interested in contributing Keras API facades. Were there any big changes since the last scalapy-tensorflow update, or can I do it in a similar way to the already-defined facades?

@shadaj
Owner Author

shadaj commented Mar 20, 2020

Hi @Avasil, that's awesome to hear! Not really; you should be able to add facades just like the existing ones.

@Avasil

Avasil commented Mar 27, 2020

@shadaj Do you have any advice on figuring out the proper type to return in a facade?

I want to try the MNIST example with static types, but I'm struggling with this:

```scala
import me.shadaj.scalapy.py
import me.shadaj.scalapy.numpy.NDArray

// Keras documents mnist.load_data() as returning ((x_train, y_train), (x_test, y_test))
@py.native trait Mnist extends py.Object {
  def load_data(): ((NDArray[Long], NDArray[Long]), (NDArray[Long], NDArray[Long])) = py.native
}

// somewhere else

py.module("keras.datasets.mnist").as[Mnist].load_data()
```

I also tried py.module("tensorflow.keras.datasets.mnist").as[Mnist].load_data(). This is the error I get:

```
[error] Exception in thread "main" scala.MatchError: jep.NDArray@a64a2421 (of class jep.NDArray)
[error] 	at me.shadaj.scalapy.py.JepPyValue.getLong(JepInterpreter.scala:193)
[error] 	at me.shadaj.scalapy.py.JepPyValue.getLong$(JepInterpreter.scala:193)
[error] 	at me.shadaj.scalapy.py.JepJavaPyValue.getLong(JepInterpreter.scala:258)
[error] 	at me.shadaj.scalapy.py.Reader$$anon$6.read(Reader.scala:42)
[error] 	at me.shadaj.scalapy.py.Reader$$anon$6.read(Reader.scala:41)
[error] 	at me.shadaj.scalapy.py.Any.as(Any.scala:15)
[error] 	at me.shadaj.scalapy.py.Any.as$(Any.scala:15)
[error] 	at me.shadaj.scalapy.py.FacadeValueProvider.as(Facades.scala:5)
[error] 	at me.shadaj.scalapy.numpy.NDArray.apply(NDArray.scala:31)
[error] 	at me.shadaj.scalapy.numpy.NDArray.$anonfun$iterator$1(NDArray.scala:33)
[error] 	at me.shadaj.scalapy.numpy.NDArray.$anonfun$iterator$1$adapted(NDArray.scala:33)
[error] 	at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
[error] 	at scala.collection.Iterator.foreach(Iterator.scala:941)
[error] 	at scala.collection.Iterator.foreach$(Iterator.scala:941)
[error] 	at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
[error] 	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
[error] 	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
[error] 	at me.shadaj.scalapy.numpy.NDArray.foreach(NDArray.scala:6)
[error] 	at scala.collection.TraversableOnce.addString(TraversableOnce.scala:362)
[error] 	at scala.collection.TraversableOnce.addString$(TraversableOnce.scala:358)
[error] 	at me.shadaj.scalapy.numpy.NDArray.addString(NDArray.scala:6)
[error] 	at scala.collection.TraversableOnce.mkString(TraversableOnce.scala:328)
[error] 	at scala.collection.TraversableOnce.mkString$(TraversableOnce.scala:327)
[error] 	at me.shadaj.scalapy.numpy.NDArray.mkString(NDArray.scala:6)
[error] 	at scala.collection.TraversableLike.toString(TraversableLike.scala:688)
[error] 	at scala.collection.TraversableLike.toString$(TraversableLike.scala:688)
[error] 	at scala.collection.SeqLike.toString(SeqLike.scala:693)
[error] 	at scala.collection.SeqLike.toString$(SeqLike.scala:693)
[error] 	at me.shadaj.scalapy.numpy.NDArray.toString(NDArray.scala:6)
[error] 	at java.base/java.lang.String.valueOf(String.java:2951)
[error] 	at java.base/java.lang.StringBuilder.append(StringBuilder.java:168)
[error] 	at scala.Tuple2.toString(Tuple2.scala:27)
[error] 	at java.base/java.lang.String.valueOf(String.java:2951)
[error] 	at java.base/java.io.PrintStream.println(PrintStream.java:897)
[error] 	at scala.Console$.println(Console.scala:271)
[error] 	at scala.Predef$.println(Predef.scala:397)
[error] 	at me.shadaj.scalapy.tensorflow.Example$.delayedEndpoint$me$shadaj$scalapy$tensorflow$Example$1(Example.scala:12)
[error] 	at me.shadaj.scalapy.tensorflow.Example$delayedInit$body.apply(Example.scala:7)
[error] 	at scala.Function0.apply$mcV$sp(Function0.scala:39)
[error] 	at scala.Function0.apply$mcV$sp$(Function0.scala:39)
[error] 	at scala.runtime.AbstractFunction0.apply$mcV$sp(AbstractFunction0.scala:17)
[error] 	at scala.App.$anonfun$main$1$adapted(App.scala:80)
[error] 	at scala.collection.immutable.List.foreach(List.scala:392)
[error] 	at scala.App.main(App.scala:80)
[error] 	at scala.App.main$(App.scala:78)
[error] 	at me.shadaj.scalapy.tensorflow.Example$.main(Example.scala:7)
[error] 	at me.shadaj.scalapy.tensorflow.Example.main(Example.scala)
```
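Because the facade's declared return type mirrors the Python value ((x_train, y_train), (x_test, y_test)), the Scala result can be unpacked with a single nested tuple pattern. The sketch below is plain Scala with no ScalaPy involved: Vector[Long] stands in for NDArray[Long], and the object NestedTupleDemo, the Split alias, loadDataStub, and its tiny placeholder values are all hypothetical, not real MNIST data.

```scala
object NestedTupleDemo {
  // Hypothetical stand-in for one (images, labels) pair; Vector[Long] replaces NDArray[Long].
  type Split = (Vector[Long], Vector[Long])

  // Hypothetical stub with the same nesting as load_data(): ((train pair), (test pair)).
  def loadDataStub(): (Split, Split) =
    ((Vector(1L, 2L, 3L), Vector(0L, 1L, 1L)), (Vector(4L), Vector(0L)))

  def main(args: Array[String]): Unit = {
    // One nested pattern binds all four parts, like Python's tuple unpacking.
    val ((xTrain, yTrain), (xTest, yTest)) = loadDataStub()
    println(s"train: ${xTrain.size} samples, ${yTrain.size} labels")
    println(s"test: ${xTest.size} samples, ${yTest.size} labels")
  }
}
```

With a real facade, the same pattern would apply directly to the value returned by load_data().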

Basically I'm going at it a bit blindly and would appreciate any tips :D

BTW, should Keras be referenced from the TensorFlow facade, or should it be top-level?

@shadaj
Owner Author

shadaj commented Mar 28, 2020

Ah, this is a bug in the old Jep backend for ScalaPy that was fixed in 0.3.0+17-2bfe86de, so for now you should be able to just upgrade to that version. There will likely be a full release soon, so we can upgrade to that before we merge this in.
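In sbt terms, that upgrade amounts to pinning the dependency version. A minimal sketch, assuming the core library is published under the coordinates "me.shadaj" %% "scalapy-core" and that this pre-release build can be resolved from your configured repositories (both are assumptions; check the ScalaPy README for the exact coordinates and any extra resolver):

```scala
// build.sbt (sketch): pin ScalaPy to the pre-release build that contains the Jep fix.
// Assumption: the artifact is "scalapy-core"; an additional resolver may be required.
libraryDependencies += "me.shadaj" %% "scalapy-core" % "0.3.0+17-2bfe86de"
```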

I think Keras should probably be top-level, since AFAIK most developers import it separately from regular TensorFlow.
