Привет, насколько я понимаю инструкцию MMA с ptx (пожалуйста, скажите мне, если я ошибаюсь):
У меня есть скомпилированный набор инструкций ptx, и я выбрал одну инструкцию mma:
//
// Generated by LLVM NVPTX Back-End
//
.version 8.4
.target sm_89
.address_size 64
// .globl matmul_kernel
.extern .shared .align 16 .b8 global_smem[];
.visible .entry matmul_kernel(
.param .u64 matmul_kernel_param_0,
.param .u64 matmul_kernel_param_1,
.param .u64 matmul_kernel_param_2,
.param .u32 matmul_kernel_param_3,
.param .u32 matmul_kernel_param_4,
.param .u32 matmul_kernel_param_5,
.param .u32 matmul_kernel_param_6,
.param .u32 matmul_kernel_param_7,
.param .u32 matmul_kernel_param_8
)
.maxntid 128, 1, 1
{
...
ldmatrix.sync.aligned.m8n8.x4.shared.b16 { %r3100, %r3101, %r3102, %r3103 }, [ %r561 + 0 ];
...
ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 { %r3084, %r3085, %r3086, %r3087 }, [ %r581 + 0 ];
...
mov.f32 %f2306, 0f00000000;
mov.b32 %r3107, 2;
mov.b32 %r3106, 0;
shl.b32 %r2885, %r100, 1;
shl.b32 %r2894, %r101, 1;
shl.b32 %r2895, %r102, 1;
shl.b32 %r2896, %r103, 1;
shl.b32 %r2897, %r104, 1;
shl.b32 %r2898, %r105, 1;
shl.b32 %r2899, %r106, 1;
mov.u32 %r3104, %r765;
mov.u32 %r3105, %r758;
mov.f32 %f2307, %f2306;
mov.f32 %f2308, %f2306;
mov.f32 %f2309, %f2306;
...
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %f2306, %f2307, %f2308, %f2309 }, { %r3100, %r3101, %r3102, %r3103 }, { %r3084, %r3085 }, { %f2306, %f2307, %f2308, %f2309 };
}
Теперь мои вопросы:
sync
, что целевая деформация будет продолжаться до тех пор, пока матрица не завершится тензорным ядром?m16n8k64
для разреженной и .m16n8k32
для плотной формы, для m16n8k16
нет поддержки, или m16n8k64
означает, что она поддерживает любую форму, которая m<16 и n<8 и k<64?Да, для операндов шириной 32 бита (например, операнды f32, такие как C, D) каждый регистр содержит один операнд. Для операндов шириной 16 бит (например, операнды f16, такие как A, B) каждый регистр содержит два операнда. И для C, и для D для этой инструкции нам нужно всего 16x8 = 128 операндов, и они распределяются по четыре на каждый поток в варпе (4x32 = 128). Для A нам нужно 16x16 (= 256) операндов, и они распределяются по 8 на поток в варпе (8x32 = 256). Однако, поскольку эти операнды f16 хранятся по два в 32-битном регистре, нам нужно всего 128 регистров в масштабе Warp, что составляет четыре регистра на поток, как и в случае C и D. Для B нам нужны 16x8 16-битных операндов, и это работает. до двух 32-битных регистров на поток в варпе.
Тег синхронизации означает, что в предыдущей ситуации отклонения от деформации эта инструкция будет ждать, пока все потоки в деформации не достигнут точки инструкции, прежде чем отправлять инструкцию в модуль тензорного ядра. Если бы мы представили несинхронный вариант инструкции (его не существует, это просто для иллюстрации), то это было бы так, как если бы мы выполнили __syncwarp()
, за которым следует несинхронный вариант инструкции.
Нет, указанные инструкции поддерживают умножение матриц именно указанных размеров; не больше, не меньше. Например, инструкция m16n8k16 требует именно M=16, а не M<16. Это верно независимо от того, говорим ли мы о плотных или разреженных вариантах.
Плотный m16n8k16
для f16
поддерживается начиная с версии PTX 7.0, см. здесь.
Вот запись в этой таблице, которая показывает поддержку: mma Dense Floating-point - .f16 .m8n8k4 PTX ISA version 6.4 .m16n8k8 PTX ISA version 6.5 .m16n8k16 PTX ISA version 7.0
Итак, таблица показывает, что это новая введенная форма, верно? это означает, что более поздняя версия может использовать все формы, представленные в предыдущих???
Да, если у вас PTX версии 7.0 или новее, вы можете задать формы .m8n8k4
, .m16n8k8
и .m16n8k16
для этого конкретного варианта (mma
, f16
, Плотный). Если у вас был PTX 6.5 или выше, но меньше 7.0, вы могли выдавать только фигуры .m8n8k4
и .m16n8k8
. Если бы у вас был PTX 6.4, вы могли бы выдать только .m8n8k4
. Если у вас был PTX 6.3 или ниже, вы не смогли бы выпустить ни один из них.
спасибо за ваш ответ, но для пункта 3, как вы сказали, согласно документу ptx, версия 8.4 будет поддерживать только форму
m16n8k64
для разреженных иm16n8k32
для плотных. Тогда как же компилятор мог получитьm16n8k16
?