Skip to content

Conta i bit zero iniziali per ogni elemento del vettore AVX2, emula _mm256_lzcnt_epi32

Fai attenzione perché in questo tutorial troverai la risposta che stai cercando.

Soluzione:

float rappresenta i numeri in un formato esponenziale, quindi la conversione int->FP ci dà la posizione del bit più alto codificato nel campo dell'esponente.

Vogliamo int->float con la magnitudine arrotondata verso il basso (troncando il valore verso 0), non l'arrotondamento predefinito del più vicino. Questo potrebbe arrotondare per eccesso e rendere 0x3FFFFFFF come se fosse 0x40000000. Se si eseguono molte conversioni senza fare calcoli FP, si può impostare la modalità di arrotondamento nell'MXCSR 1 su troncamento e poi reimpostarla quando si è finito.

Altrimenti si può usare v & ~(v>>8) per mantenere gli 8 bit più significativi e azzerare alcuni o tutti i bit inferiori, compreso un bit potenzialmente impostato 8 sotto l'MSB. Questo è sufficiente per garantire che tutte le modalità di arrotondamento non arrotondino mai alla prossima potenza di due. Mantiene sempre gli 8 MSB perché v>>8 sposta 8 zeri, quindi invertiti sono 8 uno. Nelle posizioni più basse dei bit, ovunque si trovi l'MSB, vengono spostati 8 zeri dalle posizioni più alte, quindi non cancellerà mai il bit più significativo di un intero. A seconda di come si allineano i bit sotto l'MSB, potrebbe cancellare o meno altri bit sotto gli 8 più significativi.

Dopo la conversione, utilizziamo uno shift intero sul modello di bit per portare l'esponente (e il bit di segno) in basso e annullare il bias con una sottrazione a saturazione. Utilizziamo min per impostare il risultato a 32 se nessun bit è stato impostato nell'ingresso originale a 32 bit.

__m256i avx2_lzcnt_epi32 (__m256i v) {
    // prevent value from being rounded up to the next power of two
    v = _mm256_andnot_si256(_mm256_srli_epi32(v, 8), v); // keep 8 MSB

    v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert an integer to float
    v = _mm256_srli_epi32(v, 23); // shift down the exponent
    v = _mm256_subs_epu16(_mm256_set1_epi32(158), v); // undo bias
    v = _mm256_min_epi16(v, _mm256_set1_epi32(32)); // clamp at 32

    return v;
}

Nota 1: la conversione fp->int è disponibile con il troncamento (cvtt), ma la conversione int->fp è disponibile solo con l'arrotondamento predefinito (soggetto a MXCSR).

AVX512F introduce le opzioni di arrotondamento per i vettori a 512 bit, risolvendo così il problema, __m512 _mm512_cvt_roundepi32_ps( __m512i a, int r);. Ma tutte le CPU con AVX512F supportano anche AVX512CD, quindi si potrebbe usare semplicemente _mm512_lzcnt_epi32. E con AVX512VL, _mm256_lzcnt_epi32

La risposta di @aqrit sembra un uso più intelligente dei bithack FP. La mia risposta qui sotto si basa sul primo posto in cui ho cercato un bithack, che era vecchio e rivolto a scalari, quindi non cercava di evitare double (che è più ampio di int32 e quindi un problema per SIMD).

Utilizza HW firmato int->float e la saturazione delle sottrazioni di interi per gestire l'MSB impostato (float negativo), invece di riempire i bit in una mantissa per uint->double manuale. Se si può impostare MXCSR per l'arrotondamento per difetto di molti di questi valori _mm256_lzcnt_epi32è ancora più efficiente.


https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogIEEE64Float suggerisce di inserire gli interi nella mantissa di un grande doublee poi sottrarre per ottenere dall'hardware della FPU un valore normalizzato di double. (Penso che questo pezzo di magia stia facendo uint32_t -> double con la tecnica che @Mysticial spiega in Come eseguire in modo efficiente conversioni double/int64 con SSE/AVX? (che funziona per uint64_t fino a 2 52-1)

Quindi, prendere i bit dell'esponente dell'elemento double e annullare la polarizzazione.

Penso che integer log2 sia la stessa cosa di lzcnt, ma potrebbe esserci un off-by-1 alle potenze di 2.

La pagina del bithack di Standford Graphics elenca altri bithack senza rami che si potrebbero usare e che probabilmente sarebbero ancora migliori dello scalare 8x lzcnt.

Se sapete che i vostri numeri sono sempre piccoli (come meno di 2^23), potreste farlo con float ed evitare la divisione e la miscelazione.

  int v; // 32-bit integer to find the log base 2 of
  int r; // result of log_2(v) goes here
  union { unsigned int u[2]; double d; } t; // temp

  t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] = 0x43300000;
  t.u[__FLOAT_WORD_ORDER!=LITTLE_ENDIAN] = v;
  t.d -= 4503599627370496.0;
  r = (t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] >> 20) - 0x3FF;

Il codice sopra riportato carica un doppio a 64 bit (in virgola mobile IEEE-754) con un intero a 32 bit (senza bit di padding) memorizzando l'intero nella mantissa mentre l'esponente viene impostato a 252. Da questo doppio appena coniato, 252 (espresso come doppio) viene sottratto, il che imposta l'esponente risultante alla base log 2 del valore di ingresso, v. Tutto ciò che rimane è spostare i bit dell'esponente in posizione (20 bit a destra) e sottrarre il bias, 0x3FF (che è 1023 decimale).

Per eseguire questa operazione con AVX2, miscelare e miscelare le metà dispari e pari con set1_epi32(0x43300000) e _mm256_castps_pd per ottenere un __m256d. E dopo aver sottratto, _mm256_castpd_si256 e spostare / fondere le metà bassa/alta al loro posto, quindi mascherare per ottenere gli esponenti.

L'esecuzione di operazioni intere su schemi di bit FP è molto efficiente con AVX2, solo 1 ciclo di latenza aggiuntiva per un ritardo di bypass quando si eseguono spostamenti interi sull'uscita di un'istruzione matematica FP.

(TODO: scrivetelo con gli intrinseci del C++, modificatelo o qualcun altro potrebbe postarlo come risposta).


Non sono sicuro che si possa fare qualcosa con int -> doubleconversione e poi leggere il campo dell'esponente. I numeri negativi non hanno zeri iniziali e i numeri positivi forniscono un esponente che dipende dalla grandezza.

Se lo si volesse, si dovrebbe procedere con una corsia di 128 bit alla volta, mescolando per alimentare xmm -> ymm impacchettato int32_t -> impacchettato double conversione.

La domanda è anche etichettata AVXma non ci sono istruzioni per l'elaborazione dei numeri interi in AVX, il che significa che bisogna ripiegare su SSE su piattaforme che supportano AVX ma non AVX2. Di seguito mostro una versione esaustivamente testata, ma un po' pedestre. L'idea di base è la stessa delle altre risposte, in quanto il numero di zeri iniziali è determinato dalla normalizzazione in virgola mobile che avviene durante la conversione da numeri interi a virgola mobile. L'esponente del risultato ha una corrispondenza uno-a-uno con il numero di zeri iniziali, tranne per il fatto che il risultato è sbagliato nel caso di un argomento pari a zero. Concettualmente:

clz (a) = (158 - (float_as_uint32 (uint32_to_float_rz (a)) >> 23)) + (a == 0)

dove float_as_uint32() è un cast di reinterpretazione e uint32_to_float_rz() è una conversione da intero senza segno a virgola mobile. con troncamento. Una conversione normale, con arrotondamento, potrebbe far salire il risultato della conversione alla successiva potenza di due, con conseguente conteggio errato dei bit zero iniziali.

SSE non fornisce la conversione da intero troncato a virgola mobile come singola istruzione, né la conversione da interi senza segno. Questa funzionalità deve essere emulata. Non è necessario che l'emulazione sia esatta, purché non cambi la grandezza del risultato della conversione. La parte di troncamento è gestita dall'opzione inverti - spostamento a destra - andn della risposta di aqrit. Per utilizzare la conversione firmata, tagliamo il numero a metà prima della conversione, quindi raddoppiamo e incrementiamo dopo la conversione:

float approximate_uint32_to_float_rz (uint32_t a)
{
    float r = (float)(int)((a >> 1) & ~(a >> 2));
    return r + r + 1.0f;
}

Questo approccio si traduce in SSE intrinseca in sse_clz() di seguito.

#include 
#include 
#include 
#include 
#include "immintrin.h"

/* compute count of leading zero bits using floating-point normalization.

   clz(a) = (158 - (float_as_uint32 (uint32_to_float_rz (a)) >> 23)) + (a == 0)

   The problematic part here is uint32_to_float_rz(). SSE does not offer
   conversion of unsigned integers, and no rounding modes in integer to
   floating-point conversion. Since all we need is an approximate version
   that preserves order of magnitude:

   float approximate_uint32_to_float_rz (uint32_t a)
   {
      float r = (float)(int)((a >> 1) & ~(a >> 2));
      return r + r + 1.0f;
   }
*/  
__m128i sse_clz (__m128i a) 
{
    __m128 fp1 = _mm_set_ps1 (1.0f);
    __m128i zero = _mm_set1_epi32 (0);
    __m128i i158 = _mm_set1_epi32 (158);
    __m128i iszero = _mm_cmpeq_epi32 (a, zero);
    __m128i lsr1 = _mm_srli_epi32 (a, 1);
    __m128i lsr2 = _mm_srli_epi32 (a, 2);
    __m128i atrunc = _mm_andnot_si128 (lsr2, lsr1);
    __m128 atruncf = _mm_cvtepi32_ps (atrunc);
    __m128 atruncf2 = _mm_add_ps (atruncf, atruncf);
    __m128 conv = _mm_add_ps (atruncf2, fp1);
    __m128i convi = _mm_castps_si128 (conv);
    __m128i lsr23 = _mm_srli_epi32 (convi, 23);
    __m128i res = _mm_sub_epi32 (i158, lsr23);
    return _mm_sub_epi32 (res, iszero);
}

/* Portable reference implementation of 32-bit count of leading zeros */    
int clz32 (uint32_t a)
{
    uint32_t r = 32;
    if (a >= 0x00010000) { a >>= 16; r -= 16; }
    if (a >= 0x00000100) { a >>=  8; r -=  8; }
    if (a >= 0x00000010) { a >>=  4; r -=  4; }
    if (a >= 0x00000004) { a >>=  2; r -=  2; }
    r -= a - (a & (a >> 1));
    return r;
}

/* Test floating-point based count leading zeros exhaustively */
int main (void)
{
    __m128i res;
    uint32_t resi[4], refi[4];
    uint32_t count = 0;
    do {
        refi[0] = clz32 (count);
        refi[1] = clz32 (count + 1);
        refi[2] = clz32 (count + 2);
        refi[3] = clz32 (count + 3);
        res = sse_clz (_mm_set_epi32 (count + 3, count + 2, count + 1, count));
        memcpy (resi, &res, sizeof resi);
        if ((resi[0] != refi[0]) || (resi[1] != refi[1]) ||
            (resi[2] != refi[2]) || (resi[3] != refi[3])) {
            printf ("error @ %08x %08x %08x %08xn",
                    count, count+1, count+2, count+3);
            return EXIT_FAILURE;
        }
        count += 4;
    } while (count);
    return EXIT_SUCCESS;
}

Commenti e valutazioni

Ricorda che puoi interpretare se hai trovato la risposta.



Utilizzate il nostro motore di ricerca

Ricerca
Generic filters

Lascia un commento

Il tuo indirizzo email non sarà pubblicato.