#include "float4_dispatch.h"
#include <arm_neon.h> // NEON intrinsics

// File-local SIMD primitives. They are defined before float4_impl.inl is
// included below, so their names and signatures must stay stable
// (presumably the .inl builds the float4_*_impl functions on top of them --
// confirm against float4_impl.inl).
namespace
{
    using f32x4 = float32x4_t;

    // Load 4 consecutive floats starting at p into a SIMD register.
    inline f32x4 v_load(const float* p) { return vld1q_f32(p); }

    // Store the 4 lanes of v to 4 consecutive floats starting at dst.
    inline void v_store(float* dst, f32x4 v) { vst1q_f32(dst, v); }

    // Broadcast a single float across all 4 lanes.
    inline f32x4 v_set1(float s) { return vdupq_n_f32(s); }

    // Element-wise multiply.
    inline f32x4 v_mul(f32x4 a, f32x4 b) { return vmulq_f32(a, b); }

    // Element-wise divide, computed as a * (1 / b) using the hardware
    // reciprocal estimate refined by a single Newton-Raphson step.
    // NOTE: this is an approximation; results are not correctly rounded the
    // way a true IEEE division would be.
    inline f32x4 v_div(f32x4 a, f32x4 b)
    {
        float32x4_t rcp = vrecpeq_f32(b);
        // vrecpsq_f32(b, rcp) yields (2 - b * rcp), the Newton-Raphson
        // correction factor for the reciprocal estimate.
        rcp = vmulq_f32(vrecpsq_f32(b, rcp), rcp);
        return vmulq_f32(a, rcp);
    }

    // Element-wise add.
    inline f32x4 v_add(f32x4 a, f32x4 b) { return vaddq_f32(a, b); }

    // Element-wise subtract.
    inline f32x4 v_sub(f32x4 a, f32x4 b) { return vsubq_f32(a, b); }

    // Horizontal sum of all 4 lanes (for dot product, length, etc.).
    inline float v_hadd4(f32x4 a)
    {
        // Add the low and high halves: [a0 + a2, a1 + a3].
        const float32x2_t sum_pair = vadd_f32(vget_low_f32(a), vget_high_f32(a));
        // Pairwise add collapses the two partial sums into the total.
        const float32x2_t sum = vpadd_f32(sum_pair, sum_pair);
        return vget_lane_f32(sum, 0);
    }

    // Permute (x, y, z, w) into (y, z, x, y). Lane 3 is a don't-care value:
    // the only caller (v_cross) overwrites it before returning.
    inline f32x4 v_shuffle_yzx(f32x4 v)
    {
        const float32x2_t xy = vget_low_f32(v);
        const float32x2_t zw = vget_high_f32(v);
        const float32x2_t yz = vext_f32(xy, zw, 1); // (y, z)
        return vcombine_f32(yz, xy);                // (y, z, x, y)
    }

    // 3D cross product of the xyz lanes of a and b; the w lane of the result
    // is copied from a.
    //
    // A plain vextq_f32(v, v, 1) rotates all four lanes and produces
    // (y, z, w, x), which drags w into the xyz math. The shuffle used here
    // only permutes xyz.
    inline f32x4 v_cross(f32x4 a, f32x4 b)
    {
        const f32x4 a_yzx = v_shuffle_yzx(a);
        const f32x4 b_yzx = v_shuffle_yzx(b);

        // t = a * b.yzx - a.yzx * b
        //   = (ax*by - ay*bx, ay*bz - az*by, az*bx - ax*bz, <unused>)
        //   = (cross.z,       cross.x,       cross.y,       <unused>)
        const f32x4 t = vsubq_f32(vmulq_f32(a, b_yzx), vmulq_f32(a_yzx, b));

        // Applying the same permutation once more puts the components in
        // (x, y, z) order.
        const f32x4 xyz = v_shuffle_yzx(t);

        // Preserve w from the original 'a'.
        return vsetq_lane_f32(vgetq_lane_f32(a, 3), xyz, 3);
    }
}

#include "float4_impl.inl"

namespace math_backend::float4::dispatch
{
    // Install NEON64 backend: point every entry of the gFloat4 dispatch table
    // at the implementation compiled in this translation unit.
    void install_neon()
    {
        gFloat4.add           = float4_add_impl;
        gFloat4.sub           = float4_sub_impl;
        gFloat4.mul           = float4_mul_impl;
        gFloat4.mul_scalar    = float4_mul_scalar_impl;
        gFloat4.div           = float4_div_impl;
        gFloat4.div_scalar    = float4_div_scalar_impl;
        gFloat4.dot           = float4_dot_impl;
        gFloat4.length        = float4_length_impl;
        gFloat4.lengthSquared = float4_length_squared_impl;
        gFloat4.normalize     = float4_normalize_impl;
        gFloat4.normalize_mag = float4_normalize_mag_impl;
        gFloat4.lerp          = float4_lerp_impl;
        gFloat4.cross         = float4_cross_impl;
    }
}