首页 > 解决方案 > 为什么 TensorFlow 冻结图（.pb 文件）在 Java 中无法接受图像作为 feed 输入？

问题描述

代码：

/**
 * Runs a frozen TensorFlow graph on a single image through the TensorFlow Java API.
 *
 * <p>Pipeline: load the frozen graph, scale the image so its longer side is {@value #TARGET_SIZE}
 * px (aspect ratio kept), save the scaled copy as PNG, turn those PNG bytes into a normalized float
 * tensor with a small helper graph, then feed it to the graph's {@code "ImageTensor"} input and
 * fetch {@code "SemanticPredictions"}.
 */
public class tensorflow_java {

  /** Target length, in pixels, of the longer image side after the aspect-preserving resize. */
  private static final int TARGET_SIZE = 513;

  /**
   * Loads the model and the image, runs one inference, and prints the output tensor's shape.
   *
   * @param args unused
   * @throws IOException if the image cannot be read or the resized copy cannot be written
   */
  public static void main(String[] args) throws IOException {
    String modelDir = "/home/shorav/tensorflow_java/";
    String imageFile = "/home/shorav/tensorflow_java/2.png";

    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "frozen_inference_graph.pb"));

    BufferedImage image = ImageIO.read(new File(imageFile));
    if (image == null) {
      // ImageIO.read signals "no registered reader for this file" with null, not an exception.
      System.err.println("Failed to decode image: " + imageFile);
      System.exit(1);
    }

    // Scale so the longer side becomes TARGET_SIZE while keeping the aspect ratio.
    int width = image.getWidth();
    int height = image.getHeight();
    float ratio = (float) TARGET_SIZE / Math.max(width, height);
    // Clamp to 1 px: a very thin image would otherwise truncate to 0, which BufferedImage rejects.
    int scaledWidth = Math.max(1, (int) (ratio * width));
    int scaledHeight = Math.max(1, (int) (ratio * height));
    // resize() takes its size as (height, width) -- pass the arguments in that order.
    BufferedImage resized = resize(image, scaledHeight, scaledWidth);

    // Save the resized copy under its own name so that running from modelDir cannot overwrite the
    // source image, then feed exactly those bytes (not the original file) to the model.
    File resizedFile = new File("2_resized.png");
    if (!ImageIO.write(resized, "png", resizedFile)) {
      System.err.println("No PNG writer available for " + resizedFile);
      System.exit(1);
    }
    byte[] imageBytes = readAllBytesOrExit(resizedFile.toPath());

    // Every Tensor, Graph and Session holds native memory; try-with-resources releases it.
    try (Tensor<Float> inputTensor = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      System.out.println("I created tensor");
      try (Graph graph = new Graph()) {
        graph.importGraphDef(graphDef);
        // NOTE(review): graphs exported with an "ImageTensor" input often declare it as a uint8
        // placeholder of shape [1, height, width, 3] (e.g. DeepLab exports -- confirm against this
        // .pb). If so, feeding this float tensor fails with a dtype mismatch. In that case, build a
        // UInt8 tensor from the resized pixels instead of the normalized float tensor.
        try (Session session = new Session(graph);
            Tensor<?> result =
                session
                    .runner()
                    .feed("ImageTensor", inputTensor)
                    .fetch("SemanticPredictions")
                    .run()
                    .get(0)) {
          System.out.println("Output shape: " + java.util.Arrays.toString(result.shape()));
        }
      }
    }
  }

  /**
   * Reads an entire file into memory, terminating the JVM with exit status 1 if that fails.
   *
   * @param path file to read
   * @return the file contents (on failure the JVM exits instead of returning)
   */
  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null; // not reached: System.exit does not return normally
  }

  /**
   * Scales {@code img} to exactly {@code width} x {@code height} using {@link Image#SCALE_SMOOTH}.
   *
   * <p>Note the parameter order: height comes before width.
   *
   * @param img source image
   * @param height target height in pixels (must be positive)
   * @param width target width in pixels (must be positive)
   * @return a new {@code TYPE_INT_ARGB} image of the requested size
   */
  private static BufferedImage resize(BufferedImage img, int height, int width) {
    Image tmp = img.getScaledInstance(width, height, Image.SCALE_SMOOTH);
    BufferedImage resized = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
    Graphics2D g2d = resized.createGraphics();
    try {
      g2d.drawImage(tmp, 0, 0, null);
    } finally {
      g2d.dispose();
    }
    return resized;
  }

  /**
   * Turns encoded image bytes into a normalized float tensor by running a small helper graph.
   *
   * <p>The helper graph:
   *
   * <ol>
   *   <li>decodes the bytes with 3 channels;
   *   <li>casts to float;
   *   <li>adds a leading batch dimension;
   *   <li>bilinearly resizes to exactly H x W (so the aspect ratio is <em>not</em> preserved);
   *   <li>computes {@code (pixel - mean) / scale}.
   * </ol>
   *
   * <p>NOTE(review): decoding uses {@code decodeJpeg}, but main() passes PNG data. Confirm that the
   * TensorFlow version in use accepts PNG input for DecodeJpeg. The hard-coded mean/scale values
   * should also be confirmed against the preprocessing the model expects.
   *
   * @param imageBytes encoded image file contents
   * @return a float tensor that the caller owns and must close
   */
  private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);

      final int H = 513;
      final int W = 513;
      final float mean = 117f;
      final float scale = 1f;
      final Output<String> input = b.constant("input", imageBytes);
      final Output<Float> output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), Float.class),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      // The fetched tensor outlives the helper graph and session closed by try-with-resources.
      try (Session s = new Session(g)) {
        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
      }
    }
  }
}

标签：java, tensorflow, deep-learning, java-native-interface

解决方案


推荐阅读