Мне нужно решить некоторое обыкновенное дифференциальное уравнение dy/dx = f(x) = x^2 ln(x)
, и для продолжения я создаю массив xpt между пределами 0. <= xpt <= 2. Поскольку при xpt = 0 нужно быть осторожным, я определил функцию следующим образом:
def f(x):
if x <= 1.e-6:
return 0.
else:
return np.square(x)*np.log(x)
Моя вызывающая программа гласит:
Npt = 200
xpt = np.linspace(0.,2.,Npt)
fpt = np.zeros(Npt)
Однако когда я вызываю fpt = f(xpt), я получаю ошибку:
"ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"
Я могу обойти это, введя цикл for и написав
for ip in range(Npt):
fpt[ip] = f(xpt[ip])
Но это кажется хаком и неудовлетворительным.
Я попытался рассмотреть предложение об использовании a.any()
и переопределил функцию как
def Newf(x):
if ((x <= 1.e-6).all()):
return 0.
else:
return np.square(x*np.log(x))
Но это, похоже, дает f(0.) как NaN.
Любая помощь о том, как действовать, с благодарностью принимается.
Попробуйте этот код
import numpy as np
def f(x):
return np.where(x <= 1.e-6, 0., np.square(x) * np.log(x))
Npt = 200
xpt = np.linspace(0., 2., Npt)
fpt = f(xpt)
print(fpt)
np.where(condition, value_if_true, value_if_false) применяет условие поэлементно. Если условие истинно, используется value_if_true, в противном случае — value_if_false.
Вы написали функцию f так, как будто она получает одно значение, передавая ей массив. Я думаю, вы также неправильно поняли использование функций .all() и .any() в массивах np (и других итерируемых объектах в Python). Функция .any() возвращает True, если какой-либо из элементов массива имеет значения, которые можно считать «True», в случае числа она проверяет наличие элемента, который не равен 0 в массиве. Функция .all() действует аналогичным образом, только она гарантирует, что ВСЕ элементы имеют значения, которые можно интерпретировать как «Истина».
Возможное решение — изменить функцию для обработки массивов и вернуть измененный массив. Не уверен, что вы планировали сделать в своем коде позже, но я предполагаю, что вы хотели проверить, является ли f(xpt) нулями. В этом случае это должно работать:
import numpy as np
def f(array):
return np.where(array <= 1.e-6, 0., np.square(array) * np.log(array))
def main():
Npt = 200
xpt = np.linspace(0., 2., Npt)
print(not any(f(xpt)))
if __name__ == '__main__':
main()
scipy может решать дифференциальные уравнения:
import numpy as np
import scipy as sp
def f(y,x):
if x <= 1e-6:
return 0.
else:
return np.square(x)*np.log(x)
Npt = 200
C = 0.
xpt = np.linspace(0.,2.,Npt)
ypt = sp.integrate.odeint(f, C, xpt)
print(np.hstack((np.atleast_2d(xpt).T,ypt)))
Отредактирована функция f для использования np.where, вдохновленная ответом Кена.