Python и проблемы с передачей массива в функцию

Мне нужно решить некоторое обыкновенное дифференциальное уравнение dy/dx = f(x) = x^2 ln(x), и для продолжения я создаю массив xpt между пределами 0. <= xpt <= 2. Поскольку при xpt = 0 нужно быть осторожным, я определил функцию следующим образом:

def f(x):
    if x <= 1.e-6:
        return 0.
    else:
        return np.square(x)*np.log(x)

Моя вызывающая программа гласит:

Npt = 200

xpt = np.linspace(0.,2.,Npt)

fpt = np.zeros(Npt)

Однако когда я вызываю fpt = f(xpt), я получаю ошибку:

"ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"

Я могу обойти это, введя цикл for и написав

for ip in range(Npt):
    fpt[ip] = f(xpt[ip])

Но это кажется хаком и неудовлетворительным.

Я попытался рассмотреть предложение об использовании a.any() и переопределил функцию как

def Newf(x):

    if ((x <= 1.e-6).all()):
        return 0.
    else:
        return np.square(x*np.log(x))

Но это, похоже, дает f(0.) как NaN.

Любая помощь о том, как действовать, с благодарностью принимается.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
65
3
Перейти к ответу Данный вопрос помечен как решенный

Ответы 3

Ответ принят как подходящий

Попробуйте этот код

import numpy as np

def f(x):
    return np.where(x <= 1.e-6, 0., np.square(x) * np.log(x))

Npt = 200
xpt = np.linspace(0., 2., Npt)
fpt = f(xpt)

print(fpt)

np.where(condition, value_if_true, value_if_false) применяет условие поэлементно. Если условие истинно, используется value_if_true, в противном случае — value_if_false.

Вы написали функцию f так, как будто она получает одно значение, передавая ей массив. Я думаю, вы также неправильно поняли использование функций .all() и .any() в массивах np (и других итерируемых объектах в Python). Функция .any() возвращает True, если какой-либо из элементов массива имеет значения, которые можно считать «True», в случае числа она проверяет наличие элемента, который не равен 0 в массиве. Функция .all() действует аналогичным образом, только она гарантирует, что ВСЕ элементы имеют значения, которые можно интерпретировать как «Истина».

Возможное решение — изменить функцию для обработки массивов и вернуть измененный массив. Не уверен, что вы планировали сделать в своем коде позже, но я предполагаю, что вы хотели проверить, является ли f(xpt) нулями. В этом случае это должно работать:

import numpy as np


def f(array):
    return np.where(array <= 1.e-6, 0., np.square(array) * np.log(array))


def main():
    Npt = 200
    xpt = np.linspace(0., 2., Npt)
    print(not any(f(xpt)))


if __name__ == '__main__':
    main()

Отредактирована функция f для использования np.where, вдохновленная ответом Кена.

Eli Alkhazov 23.07.2024 21:56

scipy может решать дифференциальные уравнения:

import numpy as np
import scipy as sp

def f(y,x):
    if x <= 1e-6:
        return 0.
    else:
        return np.square(x)*np.log(x)

Npt = 200
C = 0.
xpt = np.linspace(0.,2.,Npt)
ypt = sp.integrate.odeint(f, C, xpt)
print(np.hstack((np.atleast_2d(xpt).T,ypt)))

Другие вопросы по теме