Я запутался в значении оценки векторных якобианских произведений, когда вектор, используемый для VJP, является неидентичным вектором-строкой. Мой вопрос относится к функциям с векторным значением, а не к скалярным функциям, таким как потери. Я покажу конкретный пример с использованием Python и JAX, но это очень общий вопрос об автоматическом дифференцировании в обратном режиме.
Рассмотрим эту простую вектор-функцию, для которой якобиан легко записать аналитически:
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import vjp, jacrev
# Define a vector-valued function (3 inputs --> 2 outputs)
def vector_func(args):
x,y,z = args
a = 2*x**2 + 3*y**2 + 4*z**2
b = 4*x*y*z
return jnp.array([a, b])
# Define the inputs
x = 2.0
y = 3.0
z = 4.0
# Compute the vector-Jacobian product at the fiducial input point (x,y,z)
val, func_vjp = vjp(vector_func, (x, y, z))
print(val)
# [99,96]
# now evaluate the function returned by vjp along with basis row vectors to pull out gradient of 1st and 2nd output components
v1 = jnp.array([1.0, 0.0]) # pulls out the gradient of the 1st component wrt the 3 inputs, i.e., first row of Jacobian
v2 = jnp.array([0.0, 1.0]) # pulls out the gradient of the 1st component wrt the 3 inputs, i.e., second row of Jacobian
gradient1 = func_vjp(v1)
print(gradient1)
# [8, 18, 32]
gradient2 = func_vjp(v2)
print(gradient2)
# [48,32,24]
Для меня это имеет смысл - мы отдельно передаем [1,0] и [0,1] в vjp_func, чтобы соответственно получить первую и вторую строки якобиана, оцененные в нашей реперной точке (x, y, z) = (2,3,4).
Но что, если мы скормим vjp_func неидентичный вектор-строку, например [2,0]? Это спрашивает, как нужно возмущать реперную точку (x, y, z), чтобы удвоить первый компонент вывода? Если да, то есть ли способ увидеть это, оценив vector_func по возмущенным значениям параметров?
Я пробовал, но не уверен:
# suppose I want to know what perturbations in (x,y,z) cause a doubling of the first output and no change in second output component
print(func_vjp(jnp.array([2.0,0.0])))
# [16,36,64]
### Attempts to use the output of vjp_func to verify that val becomes [99*2, 96]
### none of these work
print(vector_func([16,36,64]))
# [20784, 147456]
print(vector_func([x*16,y*36,z*64])
# [299184., 3538944.]
Что я делаю неправильно, используя выходные данные func_vjp для изменения реперных параметров (x, y, z) и передачи их обратно в vector_func, чтобы убедиться, что эти возмущения параметров действительно удваивают первый компонент исходного вывода и оставляют второй компонент неизменным ?
Хорошо, спасибо, но я все еще не понимаю, как использовать вывод vjp_func([2,0]) для изменения параметров (x,y,z) и передать их обратно в vector_func, чтобы убедиться, что первый компонент удваивается. Я отредактировал свой вопрос, чтобы упростить его и сделать более конкретным. Заранее благодарим вас за любую помощь, которую вы можете предоставить!
[2, 0]
не означает, что компонент будет дублироваться; это вектор, определяющий направление касательной (т. е. градиента), который отображается обратно на соответствующую касательную во входном пространстве.
Я думаю, что в своем вопросе вы путаете простые и касательные векторные пространства. Функция vector_func
— это нелинейная функция, которая отображает вектор во входном пространстве основных векторов (представленный (x, y, z)
) в вектор в выходном пространстве основных векторов (представленный val
в вашем коде).
Функция func_vjp
— это линейная функция, которая отображает вектор в выходном касательном векторном пространстве (представленном array([2, 0])
в вашем вопросе) на вектор во входном касательном векторном пространстве ([16,36,64]
в вашем вопросе).
По построению касательные векторы в этих преобразованиях представляют градиенты входной функции при заданных основных значениях. То есть, если вы бесконечно мало возмущаете исходную прямую вдоль направления вашей выходной касательной, это соответствует бесконечно малому возмущению входной прямой вдоль направления входной касательной.
Если вы хотите проверить значения, вы можете сделать что-то вроде этого:
input_primal = (x, y, z)
output_primal, func_vjp = vjp(vector_func, input_primal)
epsilon = 1E-8 # note: small value so we're near the linear regime
output_tangent = epsilon * jnp.array([0.0, 1.0])
input_tangent, = func_vjp(output_tangent)
# Compute the perturbed output given the perturbed input
perturbed_input = [p + t for p, t in zip(input_primal, input_tangent)]
perturbed_output_1 = vector_func(perturbed_input)
print(perturbed_output_1)
# [99.00001728 96.00003904]
# Perturb the output directly
perturbed_output_2 = output_primal + output_tangent
print(perturbed_output_2)
# [99. 96.00000001]
Обратите внимание, что результаты не совпадают точно, потому что VJP действителен в локально-линейном пределе, а ваша функция очень нелинейна. Но, надеюсь, это поможет прояснить, что означают эти значения прямого и касательного в контексте вычисления VJP. Математически, если бы мы вычислили это в пределе, где эпсилон стремится к нулю, результаты точно совпали бы — вычисления градиента связаны с такими бесконечно малыми пределами.
Спасибо! Трудно сравнивать perturbed_output_1 и perturbed_output_2, учитывая, что они очень похожи. Если бы мы использовали линейную функцию, можем ли мы обойтись большим эпсилоном и получить согласие perturbed_output_1 и perturbed_output_2 для этого большего возмущения? Например, vector_func = lambda args: jnp.array([2*args[0]+3*args[1], 4*args[2]])
с эпсилон = 0,1 и тем же input_primal = (2,3,4), что и раньше, я получаю output_primal = [13,16] и perturbed_output_2 = [13,16,1], как и ожидалось, но perturbed_output_1 = [13, 17,6], что неправильно.
Да, вы правы, в моем коде отсутствуют некоторые коэффициенты масштабирования. Но суть ответа (относительно отображений между первичными и касательными пространствами) заключается в том, как вы должны думать об этих вещах.
Ну ладно, спасибо, я думал, что схожу с ума. Было бы любопытно увидеть недостающие коэффициенты масштабирования и то, как вы их получили (если это просто). Итак, просто чтобы убедиться, что я понимаю: у меня сложилось (ошибочное) впечатление, что VJP может помочь найти возмущения во входных данных, которые воспроизводят желаемое изменение в выходных данных. Но это не так. Вместо этого мы просто сообщаем VJP, в каком НАПРАВЛЕНИИ мы хотим, чтобы выходные данные были возмущены, и VJP даст нам входные возмущения, которые будут двигаться в этом направлении. Какая практическая польза от этого? Делать много маленьких последовательных итераций?
Извините, что беспокою вас @jakevdp, но если у вас есть возможность, я хотел бы кратко услышать, какие коэффициенты масштабирования отсутствуют в вашем коде, так что для линейной функции мы можем использовать больший эпсилон и по-прежнему иметь согласие perturbed_output_1 и perturbed_output_2 (например, , линейная функция в моем первом комментарии выше).
Я не знаю - я подозреваю, что это связано с определителем якобиана или чем-то еще. Я не думаю, что вам когда-либо понадобится напрямую использовать такие возмущения; Я хотел, чтобы мой ответ в целом объяснил, как думать о первичных и касательных пространствах.
Я думаю, вы сами ответили на свой вопрос в последнем абзаце: VJP распространяют выходные возмущения на входные возмущения, и эти выходные возмущения в целом не будут единичными возмущениями вдоль одного выходного значения изолированно. Эти возмущения могут быть составлены из любого сочетания значений в каждом выходном измерении.