feat(vec): vec_fma

This commit is contained in:
thetek 2023-04-25 20:00:03 +02:00
parent ecc2e0245e
commit 6677629bfa
3 changed files with 46 additions and 4 deletions

View File

@ -17,16 +17,19 @@
MACRO(f32); MACRO(f64)
// prototype generators: each one expands to the declaration (no trailing
// semicolon) of the element-type-suffixed function, e.g. `vec_add_f32`.
// `a` is the read/write vector; `b` (and `c` for fma) are read-only.
#define DECLARE_VEC_ADD(T) \
    void vec_add_##T (usize n, T a[restrict n], const T b[restrict n])
#define DECLARE_VEC_SUB(T) \
    void vec_sub_##T (usize n, T a[restrict n], const T b[restrict n])
#define DECLARE_VEC_FMA(T) \
    void vec_fma_##T (usize n, T a[restrict n], const T b[restrict n], \
                      const T c[restrict n])
// exported macros
// every generic entry point forwards its arguments unchanged to the function
// that GET_FUNCTION_NAME produces from the family name and the argument `a`.

// vectorized addition - void vec_add<T>(usize n, T a[restrict n], const T b[restrict n])
#define vec_add(n, a, b) \
    GET_FUNCTION_NAME(vec_add, a)(n, a, b)

// vectorized subtraction - void vec_sub<T>(usize n, T a[restrict n], const T b[restrict n])
#define vec_sub(n, a, b) \
    GET_FUNCTION_NAME(vec_sub, a)(n, a, b)

// vectorized multiply-add - void vec_fma<T>(usize n, T a[restrict n], const T b[restrict n], const T c[restrict n])
#define vec_fma(n, a, b, c) \
    GET_FUNCTION_NAME(vec_fma, a)(n, a, b, c)
// function declarations: emit the `vec_add_<type>`, `vec_sub_<type>` and
// `vec_fma_<type>` prototypes by handing each DECLARE_VEC_* generator to
// DECLARE_FUNCTIONS (presumably one expansion per supported element type;
// its definition is the authority on which types are covered)
DECLARE_FUNCTIONS(DECLARE_VEC_ADD);
DECLARE_FUNCTIONS(DECLARE_VEC_SUB);
DECLARE_FUNCTIONS(DECLARE_VEC_FMA);
// undefine helper macros
@ -34,5 +37,6 @@ DECLARE_FUNCTIONS(DECLARE_VEC_SUB);
// the helper macros in this header are spelled without a leading `__`
// (DECLARE_FUNCTIONS, DECLARE_VEC_*), so those are the names that must be
// removed here; undefining `__`-prefixed names would be a no-op and would
// leave the helpers visible to every file that includes this header.
#undef DECLARE_FUNCTIONS
#undef DECLARE_VEC_ADD
#undef DECLARE_VEC_SUB
#undef DECLARE_VEC_FMA
#endif // VEC_H_

View File

@ -8,8 +8,9 @@
/**
* vectorized addition. each element of vector `b` will be added to its
* counterpart element in vector `a` (i.e. `a[i] += b[i]`). the result will be
* stored in `a`
* stored in `a`.
*
* @param n: number of elements to add
* @param a: first vector, result destination
* @param b: second vector
*
@ -28,8 +29,9 @@
/**
* vectorized subtraction. each element of vector `b` will be subtracted from
* its counterpart element in vector `a` (i.e. `a[i] -= b[i]`). the result will
* be stored in `a`
* be stored in `a`.
*
* @param n: number of elements to subtract
* @param a: first vector, result destination
* @param b: second vector
*
@ -45,5 +47,30 @@
a[i] -= b[i]; \
}
/**
 * vectorized multiply-add. for every index `i` in `[0, n)`, `a[i]` is
 * multiplied by `b[i]` and `c[i]` is added to the product
 * (i.e. `a[i] = a[i] * b[i] + c[i]`). the result is written back into `a`.
 *
 * note: the body uses the plain `*` and `+` operators, not `fma()`. for
 * floating-point element types, whether the operation is evaluated with a
 * single rounding (truly fused) is therefore left to the compiler.
 *
 * @param n: number of elements to process
 * @param a: first vector (multiplicand), result destination
 * @param b: second vector (for multiplication)
 * @param c: third vector (final summand)
 *
 * @usage either through `vec_fma_<type>` (e.g. `vec_fma_u32`), or through the
 * generic macro `vec_fma`
 */
#define IMPL_VEC_FMA(TYPE)                                                    \
    void                                                                      \
    vec_fma_##TYPE (usize n, TYPE a[restrict n], const TYPE b[restrict n],    \
                    const TYPE c[restrict n])                                 \
    {                                                                         \
        for (usize idx = 0; idx < n; ++idx) {                                 \
            a[idx] = a[idx] * b[idx] + c[idx];                                \
        }                                                                     \
    }
IMPLEMENT_FUNCTIONS(IMPL_VEC_ADD)
IMPLEMENT_FUNCTIONS(IMPL_VEC_SUB)
IMPLEMENT_FUNCTIONS(IMPL_VEC_FMA)

View File

@ -110,7 +110,7 @@ test_results_t
test_vec (void)
{
test_group_t group;
i32 a[42], b[42], a_orig[42];
i32 a[42], b[42], c[42], a_orig[42];
bool all_correct;
usize i;
@ -119,6 +119,7 @@ test_vec (void)
for (i = 0; i < 42; i++) {
a[i] = a_orig[i] = (i32) rand_range (0, INT16_MAX);
b[i] = (i32) rand_range (0, INT16_MAX);
c[i] = (i32) rand_range (0, INT16_MAX);
}
vec_add (42, a, b);
@ -131,13 +132,23 @@ test_vec (void)
for (i = 0; i < 42; i++)
a[i] = a_orig[i];
vec_sub(42, a, b);
vec_sub (42, a, b);
all_correct = true;
for (i = 0; i < 42; i++)
all_correct &= (a[i] == a_orig[i] - b[i]);
test_add (&group, test_assert (all_correct && "for all i: a[i] = a_orig[i] - b[i]"), "vec_sub");
for (i = 0; i < 42; i++)
a[i] = a_orig[i];
vec_fma (42, a, b, c);
all_correct = true;
for (i = 0; i < 42; i++)
all_correct &= (a[i] == a_orig[i] * b[i] + c[i]);
test_add (&group, test_assert (all_correct && "for all i: a[i] = a_orig[i] * b[i] + c[i]"), "vec_fma");
return test_group_get_results (&group);
}