Создание подкласса Numpy Array - распространение атрибутов

Я хотел бы знать, как можно распространять настраиваемые атрибуты массивов numpy, даже если массив проходит через такие функции, как np.fromfunction.

Например, мой класс ExampleTensor определяет атрибут attr, который по умолчанию установлен на 1.

import numpy as np

class ExampleTensor(np.ndarray):
    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def __array_finalize__(self, obj) -> None:
        if obj is None: return
        # This attribute should be maintained!
        self.attr = getattr(obj, 'attr', 1)

Нарезка и базовые операции между экземплярами ExampleTensor будут поддерживать атрибуты, но использование других функций numpy не будет (вероятно, потому что они создают обычные массивы numpy вместо ExampleTensors). Мой вопрос: Есть ли решение, которое сохраняет настраиваемые атрибуты, когда обычный массив numpy построен из подклассов количество экземпляров массива?

Пример воспроизведения проблемы:

ex1 = ExampleTensor([[3, 4],[5, 6]])
ex1.attr = "some val"

print(ex1[0].attr)    # correctly outputs "some val"
print((ex1+ex1).attr) # correctly outputs "some val"

np.sum([ex1, ex1], axis=0).attr # Attribute Error: 'numpy.ndarray' object has no attribute 'attr'

Посмотрите на код других подклассов, таких как np.matrix и маскированные массивы.

hpaulj 25.07.2018 16:42

Я думаю, что это довольно интересный вопрос, и ответ, который резюмирует код для подклассов, таких как np.matrix, был бы неплохим.

JE_Muc 25.07.2018 17:35
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
19
2
2 639
4
Перейти к ответу Данный вопрос помечен как решенный

Ответы 4

Думаю, ваш пример неверен:

>>> type(ex1)
<class '__main__.ExampleTensor'>

но

>>> type([ex1, ex1])
<class 'numpy.ndarray'>

для которых ваши перегруженные __new__ и __array_finalize__ не вызываются, поскольку вы фактически создаете массив, а не свой подкласс. Однако они называются, если вы это делаете:

>>> ExampleTensor([ex1, ex1])

который устанавливает attr = 1, поскольку вы не определили, как распространять атрибут при построении ExampleTensor из списка ExampleTensor. Вам нужно будет определить это поведение в своем подклассе, перегрузив соответствующие операции. Как было предложено в комментариях выше, для вдохновения стоит взглянуть на код для np.matrix.

Ответ принят как подходящий
import numpy as np

class ExampleTensor(np.ndarray):
    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def __array_finalize__(self, obj) -> None:
        if obj is None: return
        # This attribute should be maintained!
        default_attributes = {"attr": 1}
        self.__dict__.update(default_attributes)  # another way to set attributes

Реализуйте метод array_ufunc следующим образом

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):  # this method is called whenever you use a ufunc
        f = {
            "reduce": ufunc.reduce,
            "accumulate": ufunc.accumulate,
            "reduceat": ufunc.reduceat,
            "outer": ufunc.outer,
            "at": ufunc.at,
            "__call__": ufunc,
        }
        output = ExampleTensor(f[method](*(i.view(np.ndarray) for i in inputs), **kwargs))  # convert the inputs to np.ndarray to prevent recursion, call the function, then cast it back as ExampleTensor
        output.__dict__ = self.__dict__  # carry forward attributes
        return output

Контрольная работа

x = ExampleTensor(np.array([1,2,3]))
x.attr = 2

y0 = np.add(x, x)
print(y0, y0.attr)
y1 = np.add.outer(x, x)
print(y1, y1.attr)  # works even if called with method

[2 4 6] 2
[[2 3 4]
 [3 4 5]
 [4 5 6]] 2

Объяснение в комментариях.

в __array_ufunc__ мне пришлось использовать 'output .__ dict __. update (self .__ dict __)' вместо output.__dict__ = self.__dict__, иначе я бы получил сообщение об ошибке ndarray не имеет атрибута __dict__

Thawn 12.02.2020 13:05

также i.view(np.ndarray) for i in inputs выдает ошибку, если ввод не является ndarray (например, с плавающей точкой или целым числом). Вместо этого используйте i.view(np.ndarray) if isinstance(i, ExampleTensor) else i for i in inputs.

Thawn 12.02.2020 13:51

Вам также нужно переопределить __array_function__ таким же образом для распространения атрибутов при использовании таких методов, как np.einsum?

Kevin 15.04.2021 23:24

Какое значение должно «размножаться», если ex1.attr != ex2.attr для np.sum([ex1, ex2], axis=0).attr?

Обратите внимание, что этот вопрос более фундаментален, чем может показаться на первый взгляд: как вообще большое количество функций numpy могло самостоятельно определить ваше намерение? Вероятно, вы не сможете избежать написания перегруженной версии для каждой из функций с учетом атрибута, например этой:

def sum(a, **kwargs):
    sa=np.sum(a, **kwargs)
    if isinstance(a[0],ExampleTensor): # or if hasattr(a[0],'attr')
        sa.attr=a[0].attr
    return sa

Я уверен, что этого недостаточно для обработки любого ввода np.sum (), но он должен работать для вашего примера.

Вот попытка, которая работает для операторов, которые не являются массивами, и даже когда наш подкласс указан как вывод numpy ufunc (пояснения в комментариях):

import numpy as np


class ArraySubclass(np.ndarray):
    '''Subclass of ndarray MUST be initialized with a numpy array as first argument.
    '''
    def __new__(cls, input_array, a=None, b=1):
        obj = np.asarray(input_array).view(cls)
        obj.a = a
        obj.b = b
        return obj

    def __array_finalize__(self, obj):
        if obj is None:  # __new__ handles instantiation
            return
        '''we essentially need to set all our attributes that are set in __new__ here again (including their default values). 
        Otherwise numpy's view-casting and new-from-template mechanisms would break our class.
        '''
        self.a = getattr(obj, 'a', None)
        self.b = getattr(obj, 'b', 1)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):  # this method is called whenever you use a ufunc
        '''this implementation of __array_ufunc__ makes sure that all custom attributes are maintained when a ufunc operation is performed on our class.'''

        # convert inputs and outputs of class ArraySubclass to np.ndarray to prevent infinite recursion
        args = ((i.view(np.ndarray) if isinstance(i, ArraySubclass) else i) for i in inputs)
        outputs = kwargs.pop('out', None)
        if outputs:
            kwargs['out'] = tuple((o.view(np.ndarray) if isinstance(o, ArraySubclass) else o) for o in outputs)
        else:
            outputs = (None,) * ufunc.nout
        # call numpys implementation of __array_ufunc__
        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)  # pylint: disable=no-member
        if results is NotImplemented:
            return NotImplemented
        if method == 'at':
            # method == 'at' means that the operation is performed in-place. Therefore, we are done.
            return
        # now we need to make sure that outputs that where specified with the 'out' argument are handled corectly:
        if ufunc.nout == 1:
            results = (results,)
        results = tuple((self._copy_attrs_to(result) if output is None else output)
                        for result, output in zip(results, outputs))
        return results[0] if len(results) == 1 else results

    def _copy_attrs_to(self, target):
        '''copies all attributes of self to the target object. target must be a (subclass of) ndarray'''
        target = target.view(ArraySubclass)
        try:
            target.__dict__.update(self.__dict__)
        except AttributeError:
            pass
        return target

и вот соответствующие юнит-тесты:

import unittest
class TestArraySubclass(unittest.TestCase):
    def setUp(self):
        self.shape = (10, 2, 5)
        self.subclass = ArraySubclass(np.zeros(self.shape))

    def test_instantiation(self):
        self.assertIsInstance(self.subclass, np.ndarray)
        self.assertIs(self.subclass.a, None)
        self.assertEqual(self.subclass.b, 1)
        self.assertEqual(self.subclass.shape, self.shape)
        self.assertTrue(np.array_equal(self.subclass, np.zeros(self.shape)))
        sub2 = micdata.arrayasubclass.ArraySubclass(np.zeros(self.shape), a=2)
        self.assertEqual(sub2.a, 2)

    def test_view_casting(self):
        self.assertIsInstance(np.zeros(self.shape).view(ArraySubclass),ArraySubclass)

    def test_new_from_template(self):
        self.subclass.a = 5
        bla = self.subclass[3, :]
        self.assertIsInstance(bla, ArraySubclass)
        self.assertIs(bla.a, 5)
        self.assertEqual(bla.b, 1)

    def test_np_min(self):
        self.assertEqual(np.min(self.subclass), 0)

    def test_ufuncs(self):
        self.subclass.b = 2
        self.subclass += 2
        self.assertTrue(np.all(self.subclass == 2))
        self.subclass = self.subclass + np.ones(self.shape)
        self.assertTrue(np.all(self.subclass == 3))
        np.multiply.at(self.subclass, slice(0, 2), 2)
        self.assertTrue(np.all(self.subclass[:2] == 6))
        self.assertTrue(np.all(self.subclass[2:] == 3))
        self.assertEqual(self.subclass.b, 2)

    def test_output(self):
        self.subclass.a = 3
        bla = np.ones(self.shape)
        bla *= 2
        np.multiply(bla, bla, out=self.subclass)
        self.assertTrue(np.all(self.subclass == 5))
        self.assertEqual(self.subclass.a, 3)

P.s. tempname123 получил почти правильно. Однако его ответ не работает для операторов, которые не являются массивами, и когда его класс указан как вывод ufunc:

>>> ExampleTensor += 1
AttributeError: 'int' object has no attribute 'view'
>>> np.multiply(np.ones((5)), np.ones((5)), out=ExampleTensor)
RecursionError: maximum recursion depth exceeded in comparison

Другие вопросы по теме