Nothing Special   »   [go: up one dir, main page]

Skip to content

Commit

Permalink
Improve: Faster binary ops on x86
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Oct 14, 2024
1 parent f459c68 commit cfc0c44
Showing 1 changed file with 150 additions and 36 deletions.
186 changes: 150 additions & 36 deletions include/simsimd/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,54 +165,168 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_sve(simsimd_b8_t const* a, simsimd_b8_t c

/**
 *  @brief  Hamming distance between two bit-vectors packed into bytes, using AVX-512 `VPOPCNTQ`.
 *
 *  @param  a        First bit-vector, `n_words` bytes long.
 *  @param  b        Second bit-vector, `n_words` bytes long.
 *  @param  n_words  Number of bytes (8-bit words) in each input; zero is allowed and yields 0.
 *  @param  result   Output: the number of bit positions in which `a` and `b` differ.
 *
 *  Inputs of up to 256 bytes (2048 bits) take a fully unrolled path: every complete 64-byte chunk
 *  is read with a plain load, and the remaining 1..64 bytes with one zero-masked load. Longer
 *  inputs use a loop that consumes 64 bytes per iteration and finishes with a masked tail load.
 */
SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words,
                                           simsimd_distance_t* result) {
    simsimd_size_t xor_count;
    // It's harder to squeeze out performance from tiny representations, so we unroll the loops for binary metrics.
    // Every unrolled branch reads its leading chunks with plain 64-byte loads and only the LAST chunk with a masked
    // load. `_bzhi_u64` saturates at 64 set bits, so that masked load can never cover more than 64 bytes; hence
    // each branch's upper bound must be exactly (number of chunks) * 64 bytes, or trailing bytes would be skipped.
    // Masked-out lanes are zeroed in both inputs, so they contribute nothing to the XOR population count.
    if (n_words <= 64) { // Up to 512 bits.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words);
        __m512i a_vec = _mm512_maskz_loadu_epi8(mask, a);
        __m512i b_vec = _mm512_maskz_loadu_epi8(mask, b);
        __m512i xor_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a_vec, b_vec));
        xor_count = _mm512_reduce_add_epi64(xor_count_vec);
    } else if (n_words <= 128) { // Up to 1024 bits.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 64);
        __m512i a1_vec = _mm512_loadu_epi8(a);
        __m512i b1_vec = _mm512_loadu_epi8(b);
        __m512i a2_vec = _mm512_maskz_loadu_epi8(mask, a + 64);
        __m512i b2_vec = _mm512_maskz_loadu_epi8(mask, b + 64);
        __m512i xor1_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a1_vec, b1_vec));
        __m512i xor2_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a2_vec, b2_vec));
        xor_count = _mm512_reduce_add_epi64(_mm512_add_epi64(xor2_count_vec, xor1_count_vec));
    } else if (n_words <= 192) { // Up to 1536 bits: two full chunks plus a 1..64-byte masked tail.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 128);
        __m512i a1_vec = _mm512_loadu_epi8(a);
        __m512i b1_vec = _mm512_loadu_epi8(b);
        __m512i a2_vec = _mm512_loadu_epi8(a + 64);
        __m512i b2_vec = _mm512_loadu_epi8(b + 64);
        __m512i a3_vec = _mm512_maskz_loadu_epi8(mask, a + 128);
        __m512i b3_vec = _mm512_maskz_loadu_epi8(mask, b + 128);
        __m512i xor1_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a1_vec, b1_vec));
        __m512i xor2_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a2_vec, b2_vec));
        __m512i xor3_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a3_vec, b3_vec));
        xor_count =
            _mm512_reduce_add_epi64(_mm512_add_epi64(xor3_count_vec, _mm512_add_epi64(xor2_count_vec, xor1_count_vec)));
    } else if (n_words <= 256) { // Up to 2048 bits: three full chunks plus a 1..64-byte masked tail.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 192);
        __m512i a1_vec = _mm512_loadu_epi8(a);
        __m512i b1_vec = _mm512_loadu_epi8(b);
        __m512i a2_vec = _mm512_loadu_epi8(a + 64);
        __m512i b2_vec = _mm512_loadu_epi8(b + 64);
        __m512i a3_vec = _mm512_loadu_epi8(a + 128);
        __m512i b3_vec = _mm512_loadu_epi8(b + 128);
        __m512i a4_vec = _mm512_maskz_loadu_epi8(mask, a + 192);
        __m512i b4_vec = _mm512_maskz_loadu_epi8(mask, b + 192);
        __m512i xor1_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a1_vec, b1_vec));
        __m512i xor2_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a2_vec, b2_vec));
        __m512i xor3_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a3_vec, b3_vec));
        __m512i xor4_count_vec = _mm512_popcnt_epi64(_mm512_xor_si512(a4_vec, b4_vec));
        xor_count = _mm512_reduce_add_epi64(_mm512_add_epi64(_mm512_add_epi64(xor4_count_vec, xor3_count_vec),
                                                             _mm512_add_epi64(xor2_count_vec, xor1_count_vec)));
    } else {
        // General case: accumulate per-lane popcounts, 64 bytes per iteration, then reduce once at the end.
        __m512i xor_count_vec = _mm512_setzero_si512();
        __m512i a_vec, b_vec;

    simsimd_hamming_b8_ice_cycle:
        if (n_words < 64) {
            __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words);
            a_vec = _mm512_maskz_loadu_epi8(mask, a);
            b_vec = _mm512_maskz_loadu_epi8(mask, b);
            n_words = 0;
        } else {
            a_vec = _mm512_loadu_epi8(a);
            b_vec = _mm512_loadu_epi8(b);
            a += 64, b += 64, n_words -= 64;
        }
        __m512i xor_vec = _mm512_xor_si512(a_vec, b_vec);
        xor_count_vec = _mm512_add_epi64(xor_count_vec, _mm512_popcnt_epi64(xor_vec));
        if (n_words)
            goto simsimd_hamming_b8_ice_cycle;

        xor_count = _mm512_reduce_add_epi64(xor_count_vec);
    }
    *result = xor_count;
}

/**
 *  @brief  Jaccard distance between two bit-vectors packed into bytes, using AVX-512 `VPOPCNTQ`.
 *
 *  @param  a        First bit-vector, `n_words` bytes long.
 *  @param  b        Second bit-vector, `n_words` bytes long.
 *  @param  n_words  Number of bytes (8-bit words) in each input; zero is allowed.
 *  @param  result   Output: `1 - popcount(a & b) / popcount(a | b)`, or `1` when `popcount(a | b) == 0`.
 *
 *  Inputs of up to 256 bytes (2048 bits) take a fully unrolled path: every complete 64-byte chunk
 *  is read with a plain load, and the remaining 1..64 bytes with one zero-masked load. Longer
 *  inputs use a loop that consumes 64 bytes per iteration and finishes with a masked tail load.
 */
SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words,
                                           simsimd_distance_t* result) {
    simsimd_size_t intersection = 0, union_ = 0;
    // It's harder to squeeze out performance from tiny representations, so we unroll the loops for binary metrics.
    // Every unrolled branch reads its leading chunks with plain 64-byte loads and only the LAST chunk with a masked
    // load. `_bzhi_u64` saturates at 64 set bits, so that masked load can never cover more than 64 bytes; hence
    // each branch's upper bound must be exactly (number of chunks) * 64 bytes, or trailing bytes would be skipped.
    // Masked-out lanes are zeroed in both inputs, so they add nothing to either the AND or the OR population count.
    if (n_words <= 64) { // Up to 512 bits.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words);
        __m512i a_vec = _mm512_maskz_loadu_epi8(mask, a);
        __m512i b_vec = _mm512_maskz_loadu_epi8(mask, b);
        __m512i and_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a_vec, b_vec));
        __m512i or_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a_vec, b_vec));
        intersection = _mm512_reduce_add_epi64(and_count_vec);
        union_ = _mm512_reduce_add_epi64(or_count_vec);
    } else if (n_words <= 128) { // Up to 1024 bits.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 64);
        __m512i a1_vec = _mm512_loadu_epi8(a);
        __m512i b1_vec = _mm512_loadu_epi8(b);
        __m512i a2_vec = _mm512_maskz_loadu_epi8(mask, a + 64);
        __m512i b2_vec = _mm512_maskz_loadu_epi8(mask, b + 64);
        __m512i and1_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a1_vec, b1_vec));
        __m512i or1_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a1_vec, b1_vec));
        __m512i and2_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a2_vec, b2_vec));
        __m512i or2_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a2_vec, b2_vec));
        intersection = _mm512_reduce_add_epi64(_mm512_add_epi64(and2_count_vec, and1_count_vec));
        union_ = _mm512_reduce_add_epi64(_mm512_add_epi64(or2_count_vec, or1_count_vec));
    } else if (n_words <= 192) { // Up to 1536 bits: two full chunks plus a 1..64-byte masked tail.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 128);
        __m512i a1_vec = _mm512_loadu_epi8(a);
        __m512i b1_vec = _mm512_loadu_epi8(b);
        __m512i a2_vec = _mm512_loadu_epi8(a + 64);
        __m512i b2_vec = _mm512_loadu_epi8(b + 64);
        __m512i a3_vec = _mm512_maskz_loadu_epi8(mask, a + 128);
        __m512i b3_vec = _mm512_maskz_loadu_epi8(mask, b + 128);
        __m512i and1_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a1_vec, b1_vec));
        __m512i or1_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a1_vec, b1_vec));
        __m512i and2_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a2_vec, b2_vec));
        __m512i or2_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a2_vec, b2_vec));
        __m512i and3_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a3_vec, b3_vec));
        __m512i or3_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a3_vec, b3_vec));
        intersection =
            _mm512_reduce_add_epi64(_mm512_add_epi64(and3_count_vec, _mm512_add_epi64(and2_count_vec, and1_count_vec)));
        union_ =
            _mm512_reduce_add_epi64(_mm512_add_epi64(or3_count_vec, _mm512_add_epi64(or2_count_vec, or1_count_vec)));
    } else if (n_words <= 256) { // Up to 2048 bits: three full chunks plus a 1..64-byte masked tail.
        __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words - 192);
        __m512i a1_vec = _mm512_loadu_epi8(a);
        __m512i b1_vec = _mm512_loadu_epi8(b);
        __m512i a2_vec = _mm512_loadu_epi8(a + 64);
        __m512i b2_vec = _mm512_loadu_epi8(b + 64);
        __m512i a3_vec = _mm512_loadu_epi8(a + 128);
        __m512i b3_vec = _mm512_loadu_epi8(b + 128);
        __m512i a4_vec = _mm512_maskz_loadu_epi8(mask, a + 192);
        __m512i b4_vec = _mm512_maskz_loadu_epi8(mask, b + 192);
        __m512i and1_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a1_vec, b1_vec));
        __m512i or1_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a1_vec, b1_vec));
        __m512i and2_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a2_vec, b2_vec));
        __m512i or2_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a2_vec, b2_vec));
        __m512i and3_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a3_vec, b3_vec));
        __m512i or3_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a3_vec, b3_vec));
        __m512i and4_count_vec = _mm512_popcnt_epi64(_mm512_and_si512(a4_vec, b4_vec));
        __m512i or4_count_vec = _mm512_popcnt_epi64(_mm512_or_si512(a4_vec, b4_vec));
        intersection = _mm512_reduce_add_epi64(_mm512_add_epi64(_mm512_add_epi64(and4_count_vec, and3_count_vec),
                                                                _mm512_add_epi64(and2_count_vec, and1_count_vec)));
        union_ = _mm512_reduce_add_epi64(_mm512_add_epi64(_mm512_add_epi64(or4_count_vec, or3_count_vec),
                                                          _mm512_add_epi64(or2_count_vec, or1_count_vec)));
    } else {
        // General case: accumulate per-lane AND/OR popcounts, 64 bytes per iteration, then reduce once at the end.
        __m512i and_count_vec = _mm512_setzero_si512(), or_count_vec = _mm512_setzero_si512();
        __m512i a_vec, b_vec;

    simsimd_jaccard_b8_ice_cycle:
        if (n_words < 64) {
            __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words);
            a_vec = _mm512_maskz_loadu_epi8(mask, a);
            b_vec = _mm512_maskz_loadu_epi8(mask, b);
            n_words = 0;
        } else {
            a_vec = _mm512_loadu_epi8(a);
            b_vec = _mm512_loadu_epi8(b);
            a += 64, b += 64, n_words -= 64;
        }
        __m512i and_vec = _mm512_and_si512(a_vec, b_vec);
        __m512i or_vec = _mm512_or_si512(a_vec, b_vec);
        and_count_vec = _mm512_add_epi64(and_count_vec, _mm512_popcnt_epi64(and_vec));
        or_count_vec = _mm512_add_epi64(or_count_vec, _mm512_popcnt_epi64(or_vec));
        if (n_words)
            goto simsimd_jaccard_b8_ice_cycle;

        intersection = _mm512_reduce_add_epi64(and_count_vec);
        union_ = _mm512_reduce_add_epi64(or_count_vec);
    }
    // An empty union (both inputs all-zero) is reported as distance 1, matching the pre-existing behavior.
    *result = (union_ != 0) ? 1 - (simsimd_f64_t)intersection / (simsimd_f64_t)union_ : 1;
}

Expand Down

0 comments on commit cfc0c44

Please sign in to comment.