python - 如何解决 NeuroDiffEq 中的错误“mat1 和 mat2 形状不能相乘(1000x1 和 3x512)”?
问题描述
我是神经网络的新手,对它们的使用方式有基本的了解。我正在尝试使用人工神经网络(ANN),特别是使用 NeuroDiffEq 包来解决具有边界条件的球面拉普拉斯方程:u(r=0)=u(r=1)=0 对于所有 theta 和 phi Python。以下是相同的代码
import numpy as np
import matplotlib.pyplot as plt
import torch
from neurodiffeq import diff
from neurodiffeq.networks import FCNN
from neurodiffeq.conditions import DirichletBVPSpherical
from neurodiffeq.solvers import SolverSpherical
from neurodiffeq.monitors import MonitorSpherical
from neurodiffeq.generators import Generator3D
%matplotlib notebook
laplace = lambda u, r, theta, phi: [
diff(((r**2)*diff(u,r,order=1)), r, order=1)/r**2 +
diff((np.sin(theta))*diff(u,theta,order=1), theta, order=1)/((r**2)*(np.sin(theta))) +
diff(u,phi,order=2)/(r*np.sin(theta))**2
]
conditions = [
DirichletBVPSpherical(r_0=0.0,f=0.0,r_1=1.0,g=0.0)
]
nets = [
FCNN(n_input_units=3, n_output_units=1, hidden_units=[512]),
]
monitor=MonitorSpherical(r_min=0.0,r_max=1.0,check_every=10,shape=(10,10,10),r_scale='linear',theta_min=0,theta_max=np.pi,phi_min=0,phi_max=2*np.pi)
monitor_callback = monitor.to_callback()
solver = SolverSpherical(
pde_system=laplace,
conditions=conditions,
r_min=0.0,
r_max=1.0,
nets=nets,
train_generator=Generator3D(grid=(10, 10, 10), xyz_min=(0.0, 0.0, 0.0), xyz_max=(1.0, 1.0, 1.0), method='equally-spaced'),
valid_generator=Generator3D(grid=(10, 10, 10), xyz_min=(0.0, 0.0, 0.0), xyz_max=(1.0, 1.0, 1.0), method='equally-spaced-noisy'),
)
solver.fit(max_epochs=200, callbacks=[monitor_callback])
solution_neural_net_laplace = solver.get_solution()
我收到以下错误
mat1 and mat2 shapes cannot be multiplied (1000x1 and 3x512)
对于解决此错误的任何帮助,我将不胜感激。提前致谢!
解决方案
问题是与 相乘的形状mat1
不正确mat2
。可能您使用的是 10x10x10 = 1000 的网格,因此请尝试将其设为其他内容,即 8x8x8 = 512,或者您可以尝试将输入单位设为 1000,看看是否能解决问题。
也可能是n_input_units = 512
或n_input_units = 1000
,n_hidden_units = [something else]
(取决于您在网格中所做的更改)
推荐阅读
- javascript - js 'click' 功能/EventListener 问题
- c# - C#:如何在更改另一个下拉列表时更改下拉列表的内容?
- azure - Azure PostgreSQL 是否支持自定义证书而不是用户名和密码?
- arrays - Typescript:字符串联合和具有相同项目类型的完整数组的DRY定义
- sql-server - 如何找到分布在 SQL Server 中多行的数据的第一个和最后一个数据点
- javascript - bootstrap 4的轮播的实现不起作用
- reactjs - 来自 StyleSheet 的 React Font Awesome 样式不起作用
- sql - 创建所需的脚本
- reactjs - 带有反应模板的 Asp.net 核心 Web 应用程序:在开发中,服务器首先检查 mvc 路由,但在生产中服务器仅返回 index.html
- ios - 从 tableviewcel 触发 performSegue 在 iOS14 中不再工作