首页 > 解决方案 > iOS 上 TensorFlow Lite 解释器抛出“数据计数”错误

问题描述

我正在使用 TensorFlowLiteSwift,并且我正在使用的模型负责在将图像裁剪为梯形形状时展平图像。现在,Tensorflow 并没有提供太多的文档。所以,我一直在尝试从他们的示例项目中实现一些东西。

但这里有一个问题:它抛出错误,提示“提供的数据计数必须匹配所需的计数”,而所需的计数是 4。我在 Interpreter.swift 中回溯了 byteCount,但找不到实际设置它的地方。

那么,.tflite模型是否负责“所需计数”?如果不是,那么如何设置?

这是我认为有助于理解我的问题的一段代码:

/// Performs image preprocessing, invokes the `Interpreter`, and processes the inference results.
///
/// Four buffers are copied into interpreter inputs 0...3 (scaled image data, original height,
/// original width, corners), the model is invoked, and output tensor 0 is converted to an image
/// via `postprocessImageData`.
///
/// - Parameter item: Supplies the resized image, the original dimensions and the corner points.
/// - Returns: The image built from output tensor 0, or `nil` if a buffer could not be allocated,
///   the interpreter threw, or post-processing returned `nil`.
    func runModel(on item: ImageProcessInfo) -> UIImage? {
        // NOTE(review): `byteCount` below has no channel factor. `scaledData` writes one byte per
        // non-alpha component, so an RGB input presumably needs `* 3` — confirm against the model.
        // With `isQuantized: false` every byte is additionally widened to a `Float32`.
        let rgbData = item.resizedImage.scaledData(with: CGSize(width: 1000, height: 900),
                                                   byteCount: inputWidth * inputHeight
                                                   * batchSize,
                                                   isQuantized: false)
        
        // Converts every corner point into a `(Float, Float)` tuple; the result is still nested.
        var corner = item.corners.map { $0.map { p -> (Float, Float) in
            return (Float(p.x), Float(p.y))
            } }
        // Mutable shadow copy so that its fields can be passed with `&` below.
        var item = item
        
        // Copies 4 bytes starting at `originalHeight`.
        // NOTE(review): assumes the field is a 4-byte type matching the tensor dtype — confirm.
        guard let height = NSMutableData(capacity: 0) else { return nil }
        height.append(&item.originalHeight, length: 4)
        
        // Same as above for `originalWidth`.
        guard let width = NSMutableData(capacity: 0) else { return nil }
        width.append(&item.originalWidth, length: 4)
        
        // NOTE(review): only 4 bytes are copied here, i.e. at most one `Float`, regardless of how
        // many corner coordinates exist. Presumably the corners should be flattened into a
        // contiguous `[Float]` and copied with its full byte length — verify against the byte
        // size the model reports for input 3.
        guard let corners = NSMutableData(capacity: 0) else { return nil }
        corners.append(&corner, length: 4)
        
        do {
            // NOTE(review): `rgbData!` traps if `scaledData` returned `nil`.
            // Each `copy` throws when the provided byte count differs from the tensor's byte size.
            try interpreter.copy(rgbData!, toInputAt: 0)
            try interpreter.copy(height as Data, toInputAt: 1)
            try interpreter.copy(width as Data, toInputAt: 2)
            try interpreter.copy(corners as Data, toInputAt: 3)
            try interpreter.invoke()
            
            let outputTensor1 = try self.interpreter.output(at: 0)
            
            guard let cgImage = postprocessImageData(data: outputTensor1.data, size: CGSize(width: 1000, height: 900)) else {
                return nil
            }
            
            let outputImage = UIImage(cgImage: cgImage)
            return outputImage
            
        } catch {
            // Any interpreter error is dumped for debugging and reported to the caller as `nil`.
            dump(error)
            return nil
        }
    }

extension UIImage {
    /// Returns the image's colour data scaled to `size`, with every alpha byte removed.
    ///
    /// The image is rendered into an RGBA bitmap (alpha last); every fourth byte is skipped and
    /// the remaining bytes are packed into a buffer of `byteCount` bytes. If the bitmap yields
    /// fewer colour bytes than `byteCount`, the tail of the buffer stays zero.
    ///
    /// - Parameters:
    ///   - size: Size the image is scaled to before its bytes are extracted.
    ///   - byteCount: Number of colour-component bytes expected once alpha has been dropped.
    ///   - isQuantized: When `true` the raw bytes are returned; otherwise every byte is converted
    ///     to a `Float32` using `(value - 127.5) / 127.5`.
    /// - Returns: The packed data, or `nil` if the image has no usable `CGImage`, the bitmap could
    ///   not be rendered, `byteCount` is negative, or the bitmap holds more colour bytes than
    ///   `byteCount`.
    func scaledData(with size: CGSize, byteCount: Int, isQuantized: Bool) -> Data? {
        guard byteCount >= 0 else { return nil }
        guard let cgImage = self.cgImage, cgImage.width > 0, cgImage.height > 0 else { return nil }
        guard let imageData = imageData(from: cgImage, with: size) else { return nil }

        var scaledBytes = [UInt8](repeating: 0, count: byteCount)
        var index = 0
        for (offset, element) in imageData.enumerated() {
            // Every fourth byte is the alpha component of an RGBA pixel; skip it.
            // NOTE(review): this assumes 8 bits per component — confirm for the images in use.
            let isAlphaComponent = (offset % 4) == 3
            guard !isAlphaComponent else { continue }
            // The caller's `byteCount` is too small for this bitmap. Report failure instead of
            // trapping with an out-of-range write.
            guard index < byteCount else { return nil }
            scaledBytes[index] = element
            index += 1
        }

        if isQuantized { return Data(scaledBytes) }
        let scaledFloats = scaledBytes.map { (Float32($0) - 127.5) / 127.5 }
        return Data(copyingBufferOf: scaledFloats)
    }

    /// Draws `cgImage` into a new RGBA bitmap context of the given `size` and returns the bytes
    /// of the resulting image.
    ///
    /// - Parameters:
    ///   - cgImage: Source image; its `bitsPerComponent` and per-pixel byte width are reused.
    ///   - size: Dimensions of the bitmap to render into.
    /// - Returns: The rendered bitmap data, or `nil` if the context or the image could not be
    ///   created.
    private func imageData(from cgImage: CGImage, with size: CGSize) -> Data? {
        let bitmapInfo = CGBitmapInfo(
            rawValue: CGBitmapInfo.byteOrder32Big.rawValue | CGImageAlphaInfo.premultipliedLast.rawValue
        )
        let width = Int(size.width)
        // Bytes per pixel of the source image, applied to the target width.
        let scaledBytesPerRow = (cgImage.bytesPerRow / cgImage.width) * width
        guard let context = CGContext(
            data: nil,
            width: width,
            height: Int(size.height),
            bitsPerComponent: cgImage.bitsPerComponent,
            bytesPerRow: scaledBytesPerRow,
            space: CGColorSpaceCreateDeviceRGB(),
            bitmapInfo: bitmapInfo.rawValue)
        else {
            return nil
        }
        context.draw(cgImage, in: CGRect(origin: .zero, size: size))
        return context.makeImage()?.dataProvider?.data as Data?
    }
}

/// Copies `data` into the input tensor at `index` and returns that tensor.
///
/// - Parameters:
///   - data: Bytes to copy; their count must equal the byte size reported for the tensor.
///   - index: Index of the input tensor, valid in `0...(inputTensorCount - 1)`.
/// - Throws: `InterpreterError.invalidTensorIndex` for an out-of-range index,
///   `InterpreterError.allocateTensorsRequired` if the underlying tensor is unavailable,
///   `InterpreterError.invalidTensorDataCount` if `data.count` differs from the tensor's byte
///   size, and `InterpreterError.failedToCopyDataToInputTensor` if the copy does not succeed.
/// - Returns: The input tensor at `index`, as returned by `input(at:)`.
@discardableResult
  public func copy(_ data: Data, toInputAt index: Int) throws -> Tensor {
    let lastValidIndex = inputTensorCount - 1
    guard (0...lastValidIndex).contains(index) else {
      throw InterpreterError.invalidTensorIndex(index: index, maxIndex: lastValidIndex)
    }

    guard let cTensor = TfLiteInterpreterGetInputTensor(cInterpreter, Int32(index)) else {
      throw InterpreterError.allocateTensorsRequired
    }

    /* Error here */
    // The required size is whatever the C API reports for this tensor.
    let byteCount = TfLiteTensorByteSize(cTensor)
    if data.count != byteCount {
      throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
    }

    #if swift(>=5.0)
      let status = data.withUnsafeBytes { buffer in
        TfLiteTensorCopyFromBuffer(cTensor, buffer.baseAddress, data.count)
      }
    #else
      let status = data.withUnsafeBytes { TfLiteTensorCopyFromBuffer(cTensor, $0, data.count) }
    #endif  // swift(>=5.0)

    if status != kTfLiteOk {
      throw InterpreterError.failedToCopyDataToInputTensor
    }
    return try input(at: index)
  }

标签: ios swift tensorflow tensorflow-lite

解决方案


输入的形状是什么?你能确定是哪个输入在报尺寸不匹配吗?

乍一看,corners.append(&corner, length: 4) 似乎很奇怪——corners 真的只包含 1 个 Float(大小为 4 字节)吗?

张量的 byteCount 由底层 C API 填充:它只是返回底层 TfLiteTensor 结构体的 tensor->bytes 字段,而该字段是在模型加载阶段填充的。


推荐阅读