basic simd math function overrides

beginning the implementation of overriding the math functions with sse2 sse41 and avx2 functions
This commit is contained in:
marauder2k7 2026-02-24 20:19:34 +00:00
parent 2b375bfea4
commit a7d92c344d
5 changed files with 375 additions and 14 deletions

View file

@ -130,7 +130,7 @@ torqueAddSourceDirectories("windowManager" "windowManager/torque" "windowManager
torqueAddSourceDirectories("scene" "scene/culling" "scene/zones" "scene/mixin")
# Handle math
torqueAddSourceDirectories("math" "math/util")
torqueAddSourceDirectories("math" "math/util" "math/simd")
# Handle persistence
set(TORQUE_INCLUDE_DIRECTORIES ${TORQUE_INCLUDE_DIRECTORIES} "persistence/rapidjson")

View file

@ -0,0 +1,112 @@
#include "platform/platform.h"
#include "math/mMath.h"
#include "math/util/frustum.h"
#include <math.h> // Caution!!! Possible platform specific include
#include "math/mMathFn.h"
//################################################################
// AVX 2 Functions
//################################################################
#include <immintrin.h>
void m_point3F_bulk_dot_avx2( const F32* refVector,
const F32* dotPoints,
const U32 numPoints,
const U32 pointStride,
F32* output)
{
__m256 refX = _mm256_set1_ps(refVector[0]);
__m256 refY = _mm256_set1_ps(refVector[1]);
__m256 refZ = _mm256_set1_ps(refVector[2]);
U32 i = 0;
// Process 8 points at a time (AVX2 = 8 floats per 256-bit register)
for (; i + 7 < numPoints; i += 8)
{
// Load x, y, z components with stride
__m256 x = _mm256_set_ps(
dotPoints[(i + 7) * pointStride + 0],
dotPoints[(i + 6) * pointStride + 0],
dotPoints[(i + 5) * pointStride + 0],
dotPoints[(i + 4) * pointStride + 0],
dotPoints[(i + 3) * pointStride + 0],
dotPoints[(i + 2) * pointStride + 0],
dotPoints[(i + 1) * pointStride + 0],
dotPoints[(i + 0) * pointStride + 0]
);
__m256 y = _mm256_set_ps(
dotPoints[(i + 7) * pointStride + 1],
dotPoints[(i + 6) * pointStride + 1],
dotPoints[(i + 5) * pointStride + 1],
dotPoints[(i + 4) * pointStride + 1],
dotPoints[(i + 3) * pointStride + 1],
dotPoints[(i + 2) * pointStride + 1],
dotPoints[(i + 1) * pointStride + 1],
dotPoints[(i + 0) * pointStride + 1]
);
__m256 z = _mm256_set_ps(
dotPoints[(i + 7) * pointStride + 2],
dotPoints[(i + 6) * pointStride + 2],
dotPoints[(i + 5) * pointStride + 2],
dotPoints[(i + 4) * pointStride + 2],
dotPoints[(i + 3) * pointStride + 2],
dotPoints[(i + 2) * pointStride + 2],
dotPoints[(i + 1) * pointStride + 2],
dotPoints[(i + 0) * pointStride + 2]
);
// Multiply and accumulate: x*rx + y*ry + z*rz
__m256 dot = _mm256_mul_ps(x, refX);
dot = _mm256_fmadd_ps(y, refY, dot); // dot += y*refY
dot = _mm256_fmadd_ps(z, refZ, dot); // dot += z*refZ
// Store the results
_mm256_storeu_ps(&output[i], dot);
}
// Handle remaining points
for (; i < numPoints; i++)
{
const F32* pPoint = &dotPoints[i * pointStride];
output[i] = refVector[0] * pPoint[0] +
refVector[1] * pPoint[1] +
refVector[2] * pPoint[2];
}
}
void default_matF_x_matF_AVX2(const F32* A, const F32* B, F32* C)
{
for (int i = 0; i < 4; i++)
{
// Broadcast elements of A row
__m128 a0 = _mm_set1_ps(A[i * 4 + 0]);
__m128 a1 = _mm_set1_ps(A[i * 4 + 1]);
__m128 a2 = _mm_set1_ps(A[i * 4 + 2]);
__m128 a3 = _mm_set1_ps(A[i * 4 + 3]);
// Load columns of B (rows in memory since row-major)
__m128 b0 = _mm_loadu_ps(&B[0 * 4]); // B row 0
__m128 b1 = _mm_loadu_ps(&B[1 * 4]); // B row 1
__m128 b2 = _mm_loadu_ps(&B[2 * 4]); // B row 2
__m128 b3 = _mm_loadu_ps(&B[3 * 4]); // B row 3
// Multiply and sum
__m128 res = _mm_mul_ps(a0, b0);
res = _mm_add_ps(res, _mm_mul_ps(a1, b1));
res = _mm_add_ps(res, _mm_mul_ps(a2, b2));
res = _mm_add_ps(res, _mm_mul_ps(a3, b3));
// Store result row
_mm_storeu_ps(&C[i * 4], res);
}
}
void mInstallLibrary_AVX2()
{
m_point3F_bulk_dot = m_point3F_bulk_dot_avx2;
m_matF_x_matF = default_matF_x_matF_AVX2;
m_matF_x_matF_aligned = default_matF_x_matF_AVX2;
}

View file

@ -0,0 +1,142 @@
#include "platform/platform.h"
#include "math/mMath.h"
#include "math/util/frustum.h"
#include <math.h> // Caution!!! Possible platform specific include
#include "math/mMathFn.h"
//################################################################
// SSE2 Functions - minimum baseline
//################################################################
#include <emmintrin.h>
static void m_point3F_normalize_sse2(float* p)
{
const float val = 1.0f;
// Load vector x, y, z into SSE register (w lane unused)
__m128 vec = _mm_set_ps(0.0f, p[2], p[1], p[0]);
// Compute sum of squares: x*x + y*y + z*z
__m128 sq = _mm_mul_ps(vec, vec);
__m128 sum = _mm_add_ps(sq, _mm_shuffle_ps(sq, sq, _MM_SHUFFLE(2, 1, 0, 3)));
sum = _mm_add_ss(sum, _mm_movehl_ps(sum, sum));
// Extract scalar squared length
float squared;
_mm_store_ss(&squared, sum);
if (squared != 0.0f)
{
// Exact normalization: 1/sqrt(squared)
float factor = 1.0f / std::sqrt(squared);
__m128 factorVec = _mm_set1_ps(factor);
vec = _mm_mul_ps(vec, factorVec);
}
else
{
// Zero-length fallback
vec = _mm_set_ps(0.0f, 1.0f, 0.0f, 0.0f);
}
// Store result back
p[0] = _mm_cvtss_f32(vec);
p[1] = _mm_cvtss_f32(_mm_shuffle_ps(vec, vec, _MM_SHUFFLE(1, 1, 1, 1)));
p[2] = _mm_cvtss_f32(_mm_shuffle_ps(vec, vec, _MM_SHUFFLE(2, 2, 2, 2)));
}
static void m_point3F_normalize_f_sse2(float* p, float val)
{
__m128 vec = _mm_set_ps(0.0f, p[2], p[1], p[0]);
__m128 sq = _mm_mul_ps(vec, vec);
__m128 sum = _mm_add_ps(sq, _mm_shuffle_ps(sq, sq, _MM_SHUFFLE(2, 1, 0, 3)));
sum = _mm_add_ss(sum, _mm_movehl_ps(sum, sum));
float squared;
_mm_store_ss(&squared, sum);
if (squared != 0.0f)
{
float factor = val / std::sqrt(squared); // exact
__m128 factorVec = _mm_set1_ps(factor);
vec = _mm_mul_ps(vec, factorVec);
}
else
{
// Zero-length fallback: use unit vector along z
vec = _mm_set_ps(0.0f, val, 0.0f, 0.0f);
}
p[0] = _mm_cvtss_f32(vec);
p[1] = _mm_cvtss_f32(_mm_shuffle_ps(vec, vec, _MM_SHUFFLE(1, 1, 1, 1)));
p[2] = _mm_cvtss_f32(_mm_shuffle_ps(vec, vec, _MM_SHUFFLE(2, 2, 2, 2)));
}
static void matF_x_point4F_sse2(const float* m, const float* p, float* out)
{
__m128 point = _mm_loadu_ps(p);
__m128 r0 = _mm_loadu_ps(m + 0);
__m128 r1 = _mm_loadu_ps(m + 4);
__m128 r2 = _mm_loadu_ps(m + 8);
__m128 r3 = _mm_loadu_ps(m + 12);
// Multiply rows by vector
__m128 m0 = _mm_mul_ps(r0, point);
__m128 m1 = _mm_mul_ps(r1, point);
__m128 m2 = _mm_mul_ps(r2, point);
__m128 m3 = _mm_mul_ps(r3, point);
// Horizontal add
auto dot4 = [](__m128 v) -> float
{
__m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1));
__m128 sums = _mm_add_ps(v, shuf);
shuf = _mm_movehl_ps(shuf, sums);
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
};
out[0] = dot4(m0);
out[1] = dot4(m1);
out[2] = dot4(m2);
out[3] = dot4(m3);
}
static void m_matF_x_matF_sse2(const float* A, const float* B, float* R)
{
__m128 b0 = _mm_loadu_ps(B + 0);
__m128 b1 = _mm_loadu_ps(B + 4);
__m128 b2 = _mm_loadu_ps(B + 8);
__m128 b3 = _mm_loadu_ps(B + 12);
for (int i = 0; i < 4; i++)
{
__m128 a = _mm_loadu_ps(A + i * 4);
__m128 xxxx = _mm_shuffle_ps(a, a, _MM_SHUFFLE(0, 0, 0, 0));
__m128 yyyy = _mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1));
__m128 zzzz = _mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2));
__m128 wwww = _mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3));
__m128 row =
_mm_add_ps(
_mm_add_ps(_mm_mul_ps(xxxx, b0),
_mm_mul_ps(yyyy, b1)),
_mm_add_ps(_mm_mul_ps(zzzz, b2),
_mm_mul_ps(wwww, b3))
);
_mm_storeu_ps(R + i * 4, row);
}
}
void mInstallLibrary_SSE2()
{
m_point3F_normalize = m_point3F_normalize_sse2;
m_point3F_normalize_f = m_point3F_normalize_f_sse2;
m_matF_x_point4F = matF_x_point4F_sse2;
m_matF_x_matF = m_matF_x_matF_sse2;
m_matF_x_matF_aligned = m_matF_x_matF_sse2;
}

View file

@ -0,0 +1,107 @@
#include "platform/platform.h"
#include "math/mMath.h"
#include "math/util/frustum.h"
#include <math.h> // Caution!!! Possible platform specific include
#include "math/mMathFn.h"
//################################################################
// SSE4.1 Functions
//################################################################
#include <smmintrin.h> // SSE4.1
static void m_point3F_normalize_sse41(float* p)
{
// [x y z 0]
__m128 v = _mm_set_ps(0.0f, p[2], p[1], p[0]);
// dot = x*x + y*y + z*z
__m128 dot = _mm_dp_ps(v, v, 0x71); // xyz, result in x
float lenSq = _mm_cvtss_f32(dot);
if (lenSq != 0.0f)
{
float invLen = 1.0f / sqrtf(lenSq);
__m128 scale = _mm_set1_ps(invLen);
v = _mm_mul_ps(v, scale);
}
else
{
// fallback [0,0,1]
v = _mm_set_ps(0.0f, 1.0f, 0.0f, 0.0f);
}
p[0] = _mm_cvtss_f32(v);
p[1] = _mm_cvtss_f32(_mm_shuffle_ps(v, v, _MM_SHUFFLE(1, 1, 1, 1)));
p[2] = _mm_cvtss_f32(_mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 2, 2, 2)));
}
static void m_point3F_normalize_f_sse41(float* p, float val)
{
// [x y z 0]
__m128 v = _mm_set_ps(0.0f, p[2], p[1], p[0]);
// dot = x*x + y*y + z*z
__m128 dot = _mm_dp_ps(v, v, 0x71); // xyz, result in x
float lenSq = _mm_cvtss_f32(dot);
if (lenSq != 0.0f)
{
float invLen = val / sqrtf(lenSq);
__m128 scale = _mm_set1_ps(invLen);
v = _mm_mul_ps(v, scale);
}
else
{
// fallback [0,0,1]
v = _mm_set_ps(0.0f, 1.0f, 0.0f, 0.0f);
}
p[0] = _mm_cvtss_f32(v);
p[1] = _mm_cvtss_f32(_mm_shuffle_ps(v, v, _MM_SHUFFLE(1, 1, 1, 1)));
p[2] = _mm_cvtss_f32(_mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 2, 2, 2)));
}
void matF_x_point4F_sse41(const float* m, const float* p, float* out)
{
__m128 point = _mm_loadu_ps(p);
__m128 r0 = _mm_loadu_ps(m + 0);
__m128 r1 = _mm_loadu_ps(m + 4);
__m128 r2 = _mm_loadu_ps(m + 8);
__m128 r3 = _mm_loadu_ps(m + 12);
out[0] = _mm_cvtss_f32(_mm_dp_ps(r0, point, 0xF1));
out[1] = _mm_cvtss_f32(_mm_dp_ps(r1, point, 0xF2));
out[2] = _mm_cvtss_f32(_mm_dp_ps(r2, point, 0xF4));
out[3] = _mm_cvtss_f32(_mm_dp_ps(r3, point, 0xF8));
}
static void m_matF_x_matF_sse41(const float* A, const float* B, float* R)
{
__m128 col0 = _mm_set_ps(B[12], B[8], B[4], B[0]);
__m128 col1 = _mm_set_ps(B[13], B[9], B[5], B[1]);
__m128 col2 = _mm_set_ps(B[14], B[10], B[6], B[2]);
__m128 col3 = _mm_set_ps(B[15], B[11], B[7], B[3]);
for (int i = 0; i < 4; i++)
{
__m128 row = _mm_loadu_ps(A + i * 4);
R[i * 4 + 0] = _mm_cvtss_f32(_mm_dp_ps(row, col0, 0xF1));
R[i * 4 + 1] = _mm_cvtss_f32(_mm_dp_ps(row, col1, 0xF1));
R[i * 4 + 2] = _mm_cvtss_f32(_mm_dp_ps(row, col2, 0xF1));
R[i * 4 + 3] = _mm_cvtss_f32(_mm_dp_ps(row, col3, 0xF1));
}
}
void mInstallLibrary_SSE41()
{
m_point3F_normalize = m_point3F_normalize_sse41;
m_point3F_normalize_f = m_point3F_normalize_f_sse41;
m_matF_x_point4F = matF_x_point4F_sse41;
m_matF_x_matF = m_matF_x_matF_sse41;
m_matF_x_matF_aligned = m_matF_x_matF_sse41;
}

View file

@ -27,6 +27,9 @@
#include "math/mMath.h"
extern void mInstallLibrary_SSE2();
extern void mInstallLibrary_SSE41();
extern void mInstallLibrary_AVX2();
extern void mInstallLibrary_C();
extern void mInstallLibrary_ASM();
@ -98,22 +101,19 @@ void Math::init(U32 properties)
Con::printf(" Installing Standard C extensions");
mInstallLibrary_C();
Con::printf(" Installing Assembly extensions");
mInstallLibrary_ASM();
if (properties & CPU_PROP_FPU)
{
Con::printf(" Installing FPU extensions");
}
if (properties & CPU_PROP_MMX)
{
Con::printf(" Installing MMX extensions");
}
if (properties & CPU_PROP_SSE)
{
Con::printf(" Installing SSE extensions");
if (properties & CPU_PROP_SSE2)
mInstallLibrary_SSE2();
if(properties & CPU_PROP_SSE4_1)
mInstallLibrary_SSE41();
}
if (properties & CPU_PROP_AVX2)
{
Con::printf(" Installing AVX2 extensions");
mInstallLibrary_AVX2();
}
Con::printf(" ");