Tensorflow in Scala reflection
Problem description
I am trying to get TensorFlow for Java to work in Scala. I am using the TensorFlow Java library directly, without any Scala wrapper.
In sbt I declare the TensorFlow Java dependencies (a simplified build.sbt is given below).
If I run the HelloWorld example found here, it WORKS fine with these Scala adaptations:
import org.tensorflow.Graph
import org.tensorflow.Session
import org.tensorflow.Tensor
import org.tensorflow.TensorFlow
val g = new Graph()
val value = "Hello from " + TensorFlow.version()
val t = Tensor.create(value.getBytes("UTF-8"))
// The Java API doesn't yet include convenience functions for adding operations.
g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build()
val s = new Session(g)
val output = s.runner().fetch("MyConst").run().get(0)
However, if I use Scala reflection to compile the same function from a string, it DOES NOT WORK. Here is the snippet I ran:
import scala.reflect.runtime.{universe => ru}
import scala.tools.reflect.ToolBox
val fnStr = """
  {() =>
    import org.tensorflow.Graph
    import org.tensorflow.Session
    import org.tensorflow.Tensor
    import org.tensorflow.TensorFlow
    val g = new Graph()
    val value = "Hello from " + TensorFlow.version()
    val t = Tensor.create(value.getBytes("UTF-8"))
    g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build()
    val s = new Session(g)
    s.runner().fetch("MyConst").run().get(0)
  }
"""
val mirror = ru.runtimeMirror(getClass.getClassLoader)
val tb = mirror.mkToolBox()
val tree = tb.parse(fnStr)
val fn = tb.eval(tree).asInstanceOf[() => Any]
// and finally, execute the function
fn()
Here is a simplified build.sbt that reproduces the error:
lazy val commonSettings = Seq(
  scalaVersion := "2.12.10",
  libraryDependencies ++= {
    Seq(
      // To support runtime compilation
      "org.scala-lang" % "scala-reflect" % scalaVersion.value,
      "org.scala-lang" % "scala-compiler" % scalaVersion.value,
      // for tensorflow4java
      "org.tensorflow" % "tensorflow" % "1.15.0",
      "org.tensorflow" % "proto" % "1.15.0",
      "org.tensorflow" % "libtensorflow_jni" % "1.15.0"
    )
  }
)

lazy val `test-proj` = project
  .in(file("."))
  .settings(commonSettings)
When running the above, for example in the sbt console, I get the following error and stack trace:
java.lang.NoSuchMethodError: org.tensorflow.Session.runner()Lorg/tensorflow/Session$$Runner;
at __wrapper$1$f093d26a3c504d4381a37ef78b6c3d54.__wrapper$1$f093d26a3c504d4381a37ef78b6c3d54$.$anonfun$wrapper$1(<no source file>:15)
Please ignore the memory leaks in the code above. It does not use resource contexts to call close().
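If you do want to release those objects: in the TensorFlow Java 1.x API, Graph, Session and Tensor implement AutoCloseable. Scala 2.12 (the version in the build.sbt above) has no scala.util.Using, because that arrived in 2.13. The following is a minimal loan-pattern sketch. The helper name withResource and the object Resources are hypothetical and not part of any library.

```scala
object Resources {
  // Hypothetical loan-pattern helper for Scala 2.12, where scala.util.Using
  // (added in Scala 2.13) is not available.
  // Runs body with the resource and always calls close() afterwards, even when
  // body throws. Caveat: if close() also throws, its exception replaces the
  // one thrown by body.
  def withResource[R <: AutoCloseable, A](resource: R)(body: R => A): A =
    try body(resource)
    finally resource.close()

  // Usage sketch with the TensorFlow Java 1.x classes (not compiled here):
  //
  // val text = withResource(new Graph()) { g =>
  //   withResource(Tensor.create(value.getBytes("UTF-8"))) { t =>
  //     g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build()
  //   }
  //   withResource(new Session(g)) { s =>
  //     withResource(s.runner().fetch("MyConst").run().get(0)) { out =>
  //       new String(out.bytesValue(), "UTF-8")
  //     }
  //   }
  // }
}
```

This design keeps each resource's lifetime tied to a lexical block. If you are on Scala 2.13, scala.util.Using.resource covers the same need, and it also records a failing close() as a suppressed exception instead of losing the original one.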
Recommended answer
The problem is this bug, which appears when reflective compilation is combined with Scala-Java interop:
https://github.com/scala/bug/issues/8956
Toolbox can't typecheck a value (s.runner()) of a path-dependent type (s.Runner) if that type comes from a Java non-static inner class. Runner is exactly such a class inside org.tensorflow.Session.
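The following sketch makes the "path-dependent type" part concrete without TensorFlow. A Java non-static inner class behaves like a Scala class nested inside another class: the type of s.runner() is tied to the particular value s. The ToySession/Runner pair is hypothetical and only mimics the shape of org.tensorflow.Session. It does not reproduce bug 8956. It only shows the kind of type the Toolbox has to typecheck.

```scala
// Hypothetical stand-in with the same shape as org.tensorflow.Session:
// Runner is an inner (non-static) class, so every ToySession instance
// has its own member type, written s.Runner for a session s.
class ToySession(val name: String) {
  class Runner {
    private var fetches = Vector.empty[String]

    // Returns this session's Runner type, so calls chain as in s.runner().fetch(...).run()
    def fetch(op: String): Runner = { fetches :+= op; this }

    // Each Runner keeps a reference to its enclosing session (it reads `name`),
    // just like an instance of a Java non-static inner class.
    def run(): List[String] = fetches.toList.map(op => s"$name/$op")
  }

  def runner(): Runner = new Runner
}

object PathDependentDemo {
  def main(args: Array[String]): Unit = {
    val s = new ToySession("s1")
    val r: s.Runner = s.runner()      // the static type depends on the value s
    println(r.fetch("MyConst").run()) // List(s1/MyConst)

    val other = new ToySession("s2")
    // val wrong: s.Runner = other.runner() // does not compile: other.Runner is not s.Runner
  }
}
```

The compiler must remember which session a Runner came from. That bookkeeping is what goes wrong when the Toolbox handles the Java-defined equivalent.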
Instead, you can run the compiler manually (similarly to how Toolbox runs it):
import org.tensorflow.Tensor
import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
import scala.reflect.io.{AbstractFile, VirtualDirectory}
import scala.reflect.runtime
import scala.reflect.runtime.universe
import scala.reflect.runtime.universe._
import scala.tools.nsc.{Global, Settings}

val code: String =
  """
    |import org.tensorflow.Graph
    |import org.tensorflow.Session
    |import org.tensorflow.Tensor
    |import org.tensorflow.TensorFlow
    |
    |object Main {
    |  def foo() = () => {
    |    val g = new Graph()
    |    val value = "Hello from " + TensorFlow.version()
    |    val t = Tensor.create(value.getBytes("UTF-8"))
    |    g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build()
    |
    |    val s = new Session(g)
    |
    |    s.runner().fetch("MyConst").run().get(0)
    |  }
    |}
  """.stripMargin

// Compile into an in-memory directory, then load the result through a mirror backed by that directory
val directory = new VirtualDirectory("(memory)", None)
val runtimeMirror = createRuntimeMirror(directory, runtime.currentMirror)
compileCode(code, List(), directory)
val tensor = runObjectMethod("Main", runtimeMirror, "foo").asInstanceOf[() => Tensor[_]]
tensor() // STRING tensor with shape []

// Compiles `code` with a full scalac Global (instead of a Toolbox), writing class files to outputDirectory
def compileCode(code: String, classpathDirectories: List[AbstractFile], outputDirectory: AbstractFile): Unit = {
  val settings = new Settings
  classpathDirectories.foreach(dir => settings.classpath.prepend(dir.toString))
  settings.outputDirs.setSingleOutput(outputDirectory)
  settings.usejavacp.value = true // reuse the JVM's classpath (TensorFlow jars, Scala library)
  val global = new Global(settings)
  (new global.Run).compileSources(List(new BatchSourceFile("(inline)", code)))
}

// Looks up a top-level object by name and invokes one of its methods via scala-reflect mirrors
def runObjectMethod(objectName: String, runtimeMirror: Mirror, methodName: String, arguments: Any*): Any = {
  val objectSymbol = runtimeMirror.staticModule(objectName)
  val objectModuleMirror = runtimeMirror.reflectModule(objectSymbol)
  val objectInstance = objectModuleMirror.instance
  val objectType = objectSymbol.typeSignature
  val methodSymbol = objectType.decl(TermName(methodName)).asMethod
  val objectInstanceMirror = runtimeMirror.reflect(objectInstance)
  val methodMirror = objectInstanceMirror.reflectMethod(methodSymbol)
  methodMirror(arguments: _*)
}

// Builds a runtime mirror whose class loader can see the classes compiled into `directory`
def createRuntimeMirror(directory: AbstractFile, parentMirror: Mirror): Mirror = {
  val classLoader = new AbstractFileClassLoader(directory, parentMirror.classLoader)
  universe.runtimeMirror(classLoader)
}
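The runObjectMethod above goes through scala-reflect mirrors. For comparison, here is a minimal alternative sketch that uses only Java reflection. It assumes the standard Scala 2 encoding of a top-level object Foo as a class named Foo$ with a public static field MODULE$ that holds the singleton. The names ModuleInvoker, invokeModuleMethod and Greeter are hypothetical and are used only for illustration.

```scala
// Hypothetical alternative to runObjectMethod that uses plain Java reflection.
// Assumes the Scala 2 encoding of a top-level `object Foo`: a class named `Foo$`
// with a public static field MODULE$ holding the singleton instance.
object ModuleInvoker {
  def invokeModuleMethod(moduleClassName: String,
                         loader: ClassLoader,
                         methodName: String,
                         args: AnyRef*): AnyRef = {
    // Load (and initialize) the module class through the given class loader
    val moduleClass = Class.forName(moduleClassName, true, loader)
    val instance = moduleClass.getField("MODULE$").get(null)
    // Pick the first public method with a matching name and arity (no overload resolution)
    val method = moduleClass.getMethods
      .find(m => m.getName == methodName && m.getParameterCount == args.length)
      .getOrElse(throw new NoSuchMethodException(s"$moduleClassName.$methodName"))
    method.invoke(instance, args: _*)
  }
}

// A stand-in for the compiled Main object from the answer
object Greeter {
  def foo(): () => String = () => "Hello from Greeter"
}
```

Here is how it might be used with the answer's code (an untested assumption). After compileCode(code, List(), directory) has run, a call such as invokeModuleMethod("Main$", new AbstractFileClassLoader(directory, getClass.getClassLoader), "foo") should return the same function as runObjectMethod, without building a second runtime mirror.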
See also: How to evaluate code that uses InterfaceStability annotation (that fails with "illegal cyclic reference involving class InterfaceStability")?