java - Extending a Java class and overriding a method in Scala
Problem description
I'm trying to implement a custom loss function (DL4J) following this code example: https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/misc/lossfunctions/CustomLossL1L2.java
To do that, I need to extend ILossFunction and override some of its methods.
Problem:
The method computeGradientAndScore returns org.nd4j.linalg.primitives.Pair<Double, INDArray>.
I tried:
override def computeGradientAndScore(
    labels: INDArray,
    preOutput: INDArray,
    activationFn: IActivation,
    mask: INDArray,
    average: Boolean
): Pair[Double, INDArray] = {
  Pair.makePair(
    computeScore(labels, preOutput, activationFn, mask, average),
    computeGradient(labels, preOutput, activationFn, mask)
  )
}
and got the following compilation error:
[info] Compiling 2 Scala sources to PATH
[error] PATH/CosineSimilarity.scala:78: overriding method computeGradientAndScore in trait ILossFunction of type (x$1: org.nd4j.linalg.api.ndarray.INDArray, x$2: org.nd4j.linalg.api.ndarray.INDArray, x$3: org.nd4j.linalg.activations.IActivation, x$4: org.nd4j.linalg.api.ndarray.INDArray, x$5: Boolean)org.nd4j.linalg.primitives.Pair[Double,org.nd4j.linalg.api.ndarray.INDArray];
[error] method computeGradientAndScore has incompatible type
[error] override def computeGradientAndScore(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray, average: Boolean): Pair[Double, INDArray] = {
[error] ^
[error] one error found
[error] (root/compile:compileIncremental) Compilation failed
[error] Total time: 4 s, completed 16 mai 2018 16:45:48
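Question: how do I override this method?
Solution
The Double in the Java signature Pair<Double, INDArray> is java.lang.Double, whereas Double in Scala refers to scala.Double. Because Pair is a Java generic class, Pair[scala.Double, INDArray] and Pair[java.lang.Double, INDArray] are different types, so the override must declare the return type with java.lang.Double. The following code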
import java.lang
import org.nd4j.linalg.activations.IActivation
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.lossfunctions.ILossFunction
import org.nd4j.linalg.primitives

class MyLossFunction extends ILossFunction {
  // ??? marks placeholders: replace them with the actual loss computation.
  override def computeScore(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray, average: Boolean): Double = ???
  override def computeScoreArray(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray): INDArray = ???
  override def computeGradient(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray): INDArray = ???

  // The Java interface returns Pair<java.lang.Double, INDArray>, so the key type here
  // must be lang.Double (java.lang.Double), not scala.Double.
  override def computeGradientAndScore(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray, average: Boolean): primitives.Pair[lang.Double, INDArray] = {
    primitives.Pair.makePair(
      computeScore(labels, preOutput, activationFn, mask, average), // scala.Double, converted to java.lang.Double via Predef
      computeGradient(labels, preOutput, activationFn, mask)
    )
  }

  override def name(): String = ???
}
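To see the type mismatch without DL4J on the classpath, here is a minimal sketch that mirrors the situation. Pair, LossLike and MyLoss are hypothetical stand-ins (not DL4J classes): Pair plays the role of org.nd4j.linalg.primitives.Pair (an invariant generic class), and LossLike reproduces how a Java method returning Pair<Double, String> looks from Scala.

```scala
// Hypothetical stand-in for org.nd4j.linalg.primitives.Pair: an invariant generic class.
final class Pair[K, V](val key: K, val value: V)

// Hypothetical stand-in for the relevant part of ILossFunction, as Scala sees it:
// a Java `double` return maps to scala.Double, but a Java `Double` type argument
// stays java.lang.Double.
trait LossLike {
  def score(): Double
  def scoreAndName(): Pair[java.lang.Double, String]
}

class MyLoss extends LossLike {
  override def score(): Double = 0.5

  // Declaring the return type as Pair[Double, String] here would fail with
  // "method scoreAndName has incompatible type", because scala.Double and
  // java.lang.Double are different type arguments of the invariant Pair.
  override def scoreAndName(): Pair[java.lang.Double, String] =
    // score() is a scala.Double; the expected type java.lang.Double makes the
    // compiler apply the implicit conversion Predef.double2Double.
    new Pair[java.lang.Double, String](score(), "my-loss")
}
```

For example, new MyLoss().scoreAndName().key is a java.lang.Double holding 0.5.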
compiles with the following sbt settings:
scalaVersion := "2.12.6"
libraryDependencies += "org.deeplearning4j" % "deeplearning4j-core" % "0.9.1"
libraryDependencies += "org.nd4j" % "nd4j-native-platform" % "0.9.1" % Test
libraryDependencies += "org.datavec" % "datavec-api" % "0.9.1"
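In the solution, computeScore returns a scala.Double while the Pair key is a java.lang.Double; the makePair call presumably compiles because the expected return type lets the compiler apply the implicit conversions defined in scala.Predef (double2Double for boxing, Double2double for unboxing). A small sketch of those conversions, where BoxingDemo and computeScoreStub are hypothetical names used only for illustration:

```scala
object BoxingDemo {
  // Hypothetical placeholder for ILossFunction.computeScore, which returns scala.Double in Scala.
  def computeScoreStub(): Double = 0.25

  // Expected type java.lang.Double: the compiler inserts Predef.double2Double (boxing).
  val boxed: java.lang.Double = computeScoreStub()

  // Expected type scala.Double: the compiler inserts Predef.Double2double (unboxing).
  val unboxed: Double = boxed

  // With a generic Java class, giving the type arguments explicitly fixes the key
  // type as java.lang.Double, so the scala.Double argument is boxed.
  val entry = new java.util.AbstractMap.SimpleEntry[java.lang.Double, String](computeScoreStub(), "score")
}
```

If inference ever picks the wrong key type, the same idea applies to the DL4J factory: the type arguments can be passed explicitly, e.g. primitives.Pair.makePair[lang.Double, INDArray](...).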