python - keras 模型上奇怪的分析结果:越复杂越快
问题描述
所以我目前正在尝试找出最适合处理傅立叶变换的深度学习框架。到目前为止,我使用keras
的是tensorflow
后端,但我注意到 fft 有点慢(参见Github 上的这个问题)。
所以最近我尝试直接将速度与pytorch
. 由于我想做的不仅仅是简单地进行傅立叶变换,因此我尝试添加一些操作来进行更全面的基准测试,我注意到对于keras
,添加操作正在减少计算时间。
这是最小的工作示例(基本上是在 2D 中进行逆傅立叶变换,通过获取图像的模块完成,并且介于潜在的“去复杂化”和“重新复杂化”之间):
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
from keras.layers import Input, Lambda, concatenate
from keras.models import Model
import numpy as np
import tensorflow as tf
from tensorflow.signal import ifft2d
import time
def concatenate_real_imag(x):
x_real = Lambda(tf.math.real)(x)
x_imag = Lambda(tf.math.imag)(x)
return concatenate([x_real, x_imag])
def to_complex(x):
return tf.complex(x[0], x[1])
def complex_from_half(x, n, output_shape):
return Lambda(lambda x: to_complex([x[..., :n], x[..., n:]]), output_shape=output_shape)(x)
def weird_model(conc_then_com=False):
input_size = (320, None, 1)
kspace_input = Input(input_size, dtype='complex64', name='kspace_input')
inv_kspace = Lambda(ifft2d, output_shape=input_size)(kspace_input)
if conc_then_com:
inv_kspace = concatenate_real_imag(kspace_input)
inv_kspace = complex_from_half(inv_kspace, 1, input_size)
abs_inv_kspace = Lambda(tf.math.abs)(inv_kspace)
model = Model(inputs=kspace_input, outputs=abs_inv_kspace)
model.compile(
optimizer='adam',
loss='mse',
)
return model
# fake data
data_x = np.random.rand(35, 320, 320, 1) + 1j * np.random.rand(35, 320, 320, 1)
data_y = np.random.rand(35, 320, 320, 1)
start = time.time()
r = weird_model(conc_then_com=True).predict_on_batch(data_x)
end = time.time()
duration = end - start
print(f'For the prediction with the complex model it took {duration}')
start = time.time()
r = weird_model(conc_then_com=False).predict_on_batch(data_x)
end = time.time()
duration = end - start
print(f'For the prediction with the simple model it took {duration}')
start = time.time()
weird_model(conc_then_com=True).fit(
x=data_x,
y=data_y,
batch_size=35,
epochs=1,
verbose=2,
shuffle=False,
)
end = time.time()
duration = end - start
print(f'For the fitting with the complex model it took {duration}')
start = time.time()
weird_model(conc_then_com=False).fit(
x=data_x,
y=data_y,
batch_size=35,
epochs=1,
verbose=2,
shuffle=False,
)
end = time.time()
duration = end - start
print(f'For the fitting with the simple model it took {duration}')
这给出了以下时间(或多或少):
For the prediction with the complex model it took 0.24
For the prediction with the simple model it took 3.98
For the fitting with the complex model it took 0.28
For the fitting with the simple model it took 4.01
我不知道发生了什么。
解决方案
实际上,这只是一个错字:
inv_kspace = concatenate_real_imag(kspace_input)
应该是inv_kspace = concatenate_real_imag(inv_kspace)
推荐阅读
- sql - postgresql中的concat函数
- javascript - javascript错误乘以布局
- javascript - Javascript 不适用于 Owl-Carousel
- python - 使用 python 在 Web 浏览器中记录用户活动
- c# - 使用属性在每个节点处进行反向递归以生成 XML 面包屑的最佳方法是什么?
- git - 如何:在推送/合并到 master 时,将特定文件提交/推送到另一个项目
- java - 没有代理的 RestTemplate 调用
- ios - 带有子视图的 MapKit 地图不会添加注释
- java - 将 onClick 添加到 RecyclerView 项目的一部分
- python - 根据来自其他数据框的关系创建新的数据框