Torque3D/Engine/source/math/simd/mMath_AVX.cpp
marauder2k7 a7d92c344d basic simd math function overrides
beginning the implementation of overriding the math functions with sse2 sse41 and avx2 functions
2026-02-24 20:19:34 +00:00

112 lines
3.7 KiB
C++

#include "platform/platform.h"
#include "math/mMath.h"
#include "math/util/frustum.h"
#include <math.h> // Caution!!! Possible platform specific include
#include "math/mMathFn.h"
//################################################################
// AVX 2 Functions
//################################################################
#include <immintrin.h>
void m_point3F_bulk_dot_avx2( const F32* refVector,
const F32* dotPoints,
const U32 numPoints,
const U32 pointStride,
F32* output)
{
__m256 refX = _mm256_set1_ps(refVector[0]);
__m256 refY = _mm256_set1_ps(refVector[1]);
__m256 refZ = _mm256_set1_ps(refVector[2]);
U32 i = 0;
// Process 8 points at a time (AVX2 = 8 floats per 256-bit register)
for (; i + 7 < numPoints; i += 8)
{
// Load x, y, z components with stride
__m256 x = _mm256_set_ps(
dotPoints[(i + 7) * pointStride + 0],
dotPoints[(i + 6) * pointStride + 0],
dotPoints[(i + 5) * pointStride + 0],
dotPoints[(i + 4) * pointStride + 0],
dotPoints[(i + 3) * pointStride + 0],
dotPoints[(i + 2) * pointStride + 0],
dotPoints[(i + 1) * pointStride + 0],
dotPoints[(i + 0) * pointStride + 0]
);
__m256 y = _mm256_set_ps(
dotPoints[(i + 7) * pointStride + 1],
dotPoints[(i + 6) * pointStride + 1],
dotPoints[(i + 5) * pointStride + 1],
dotPoints[(i + 4) * pointStride + 1],
dotPoints[(i + 3) * pointStride + 1],
dotPoints[(i + 2) * pointStride + 1],
dotPoints[(i + 1) * pointStride + 1],
dotPoints[(i + 0) * pointStride + 1]
);
__m256 z = _mm256_set_ps(
dotPoints[(i + 7) * pointStride + 2],
dotPoints[(i + 6) * pointStride + 2],
dotPoints[(i + 5) * pointStride + 2],
dotPoints[(i + 4) * pointStride + 2],
dotPoints[(i + 3) * pointStride + 2],
dotPoints[(i + 2) * pointStride + 2],
dotPoints[(i + 1) * pointStride + 2],
dotPoints[(i + 0) * pointStride + 2]
);
// Multiply and accumulate: x*rx + y*ry + z*rz
__m256 dot = _mm256_mul_ps(x, refX);
dot = _mm256_fmadd_ps(y, refY, dot); // dot += y*refY
dot = _mm256_fmadd_ps(z, refZ, dot); // dot += z*refZ
// Store the results
_mm256_storeu_ps(&output[i], dot);
}
// Handle remaining points
for (; i < numPoints; i++)
{
const F32* pPoint = &dotPoints[i * pointStride];
output[i] = refVector[0] * pPoint[0] +
refVector[1] * pPoint[1] +
refVector[2] * pPoint[2];
}
}
// Multiplies two 4x4 float matrices stored as 16 consecutive floats each:
//
//   C[i*4 + j] = A[i*4+0]*B[0*4+j] + A[i*4+1]*B[1*4+j]
//              + A[i*4+2]*B[2*4+j] + A[i*4+3]*B[3*4+j]
//
// i.e. C = A * B for row-major storage.
//
//   A, B - input matrices, 16 floats each
//   C    - output matrix, 16 floats
//
// Only unaligned loads/stores are used, so none of the pointers needs to be
// 16-byte aligned. C may be exactly the same array as A or as B (in-place
// multiply); partially overlapping ranges are not supported.
//
// Despite the name, the body only uses 128-bit (__m128) operations.
void default_matF_x_matF_AVX2(const F32* A, const F32* B, F32* C)
{
   // The rows of B are identical for every output row, so fetch them once.
   // Loading them before the first store also keeps the result correct when
   // C is the same array as B.
   const __m128 b0 = _mm_loadu_ps(&B[0 * 4]); // B row 0
   const __m128 b1 = _mm_loadu_ps(&B[1 * 4]); // B row 1
   const __m128 b2 = _mm_loadu_ps(&B[2 * 4]); // B row 2
   const __m128 b3 = _mm_loadu_ps(&B[3 * 4]); // B row 3

   for (int i = 0; i < 4; i++)
   {
      // Broadcast the four elements of row i of A. These are read before
      // row i of C is written, so C == A is safe as well.
      const __m128 a0 = _mm_set1_ps(A[i * 4 + 0]);
      const __m128 a1 = _mm_set1_ps(A[i * 4 + 1]);
      const __m128 a2 = _mm_set1_ps(A[i * 4 + 2]);
      const __m128 a3 = _mm_set1_ps(A[i * 4 + 3]);

      // Row i of C is the combination of B's rows weighted by row i of A.
      __m128 res = _mm_mul_ps(a0, b0);
      res = _mm_add_ps(res, _mm_mul_ps(a1, b1));
      res = _mm_add_ps(res, _mm_mul_ps(a2, b2));
      res = _mm_add_ps(res, _mm_mul_ps(a3, b3));

      // Store result row
      _mm_storeu_ps(&C[i * 4], res);
   }
}
// Points the math dispatch entries at the implementations defined in this
// file. The assigned names (m_matF_x_matF, m_matF_x_matF_aligned,
// m_point3F_bulk_dot) are declared outside this file - presumably in
// math/mMathFn.h; verify.
void mInstallLibrary_AVX2()
{
   // 4x4 matrix multiply. One routine serves both the regular and the
   // "aligned" entry, since it only issues unaligned loads and stores.
   m_matF_x_matF_aligned = default_matF_x_matF_AVX2;
   m_matF_x_matF = default_matF_x_matF_AVX2;

   // Reference-vector vs. point-array dot product.
   m_point3F_bulk_dot = m_point3F_bulk_dot_avx2;
}