Python を使って bfloat16
blog
2019-4-15 2:28 JST

TensorFlow で学習させて、NumPy 等を使い学習済みのパラメタをデータを bfloat16 の Python の list に落とすことに成功しました。そのデータをもう一度 float に戻した上で NumPy で計算するとうまく予測してくれそう(多くの人がそう信じていろんなところで 実験しているはず)なこともわかりました。

次のステップは NumPy からも離れて bfloat16 をPython で直接扱うようにしてみましょう。その為に bfloat16 の足し算と掛け算をするプログラムを組みます。

掛算のプログラム

浮動小数点数の計算なので足し算より掛け算の方が楽です。Python のソースコードを次に掲げます。まずは正しく動くことが重要なので最適化されていません(いいわけ)。誰か賢い人、より最適化してみてください。あるいは C でオフロードしてみてください。そのまま組み込みシステムの世界で使えるようになるかもしれません。

bit16 というタイプヒントを使っていますが、気にせず行きましょう。最終的には Polyphony というコンパイラを使って FPGA で動くようにするための布石です。タイプヒントはあくまでヒントなのでPython 実行時は見事に無視してくれます。

mul.py
def mul(x:bit16, w:bit16):
    if x == 0:
        return 0
    if w == 0:
        return 0
    x_e = (x >> 7) & 0xFF 
    w_e = (w >> 7) & 0xFF
    e:bit8 = (x_e - 127) + (w_e - 127) + 127
    x_n = (x & 0x7F)
    w_n = (w & 0x7F)
    new_n:bit16 = ((x_n | 0x80) * (w_n | 0x80))
    if new_n & 0x8000:
        new_n >>= 8
        e += 1
    else:
        new_n >>= 7
    #print(x_n, w_n, new_n)
    new_n &= 0x7F
    s = (x & 0x8000) ^ (w & 0x8000)
    #print('s', s, e, new_n)
    
    x_w = s | (e << 7) | (new_n)
    #print(f'result:{x_w:2x}, {w_n} {x_n} {new_n}')
    return x_w

足し算のプログラム

桁合わせがあるので掛算よりちょっと複雑です。浮動小数点数の計算は精度があまりにも出なくなった場合ゼロにまるめたり、最後のビット等を微妙に切り上げたりというテクニックを使って、精度をなるべく落とさないようにするのが定石です。が!ここでは一切そんなことはしていません。動けばよかろう!(きっぱり)なのです(ここではね)。

add.py
def sub_add(x_sign:bit, x:bit8, b_sign:bit, b:bit8, e:bit8):
    #print(f'sub_add {x_sign} {x:08b}, {b_sign} {b:08b}')
    #print('sub_add', x_sign, x, b_sign, b, e)

    if (x_sign == 0 and b_sign == 0) or (x_sign == 1 and b_sign == 1):
        rv_n:bit9 = (x + b)
        add_e = 1 if rv_n & 0x100 else 0
        if add_e:
            rv_n >>= 1

        return (0x8000 if x_sign else 0x0000) | ((e + add_e) << 7) | (rv_n & 0x7F)
    else:
        if x < b:
            x_sign, b_sign = b_sign, x_sign
            x, b = b, x

        #print('sub_add', x_sign, x, b_sign, b, e, x - b)

        rv_n = x - b
        rv_sign = x_sign
        for i in range(0, 7): 
            #print(i, rv_n)
            if rv_n & 0x80:
                return (0x8000 if rv_sign else 0x0000) | ((e - i) << 7) | (rv_n & 0x7F)
            rv_n <<= 1

        return 0

def add(x:bit16, b:bit16):
    if x == 0:
        return b
    if b == 0:
        return x
    x_sign = 1 if x & 0x8000 else 0
    b_sign = 1 if b & 0x8000 else 0
    x_e = (x >> 7) & 0xFF 
    b_e = (b >> 7) & 0xFF

    if x_e < b_e:
        x, b = b, x
        x_e, b_e = b_e, x_e
        x_sign, b_sign = b_sign, x_sign

    x_n = (x & 0x7F)
    b_n = (b & 0x7F)

    d = x_e - b_e
    e = x_e
    if d > 8:
        return x

    new_n:bit8 = 0
    if d == 0:
        rv = sub_add(x_sign, 0x80 | x_n, b_sign, 0x80 | b_n, e)
    else:
        new_b_n = ((0x80 | b_n) >> d) + ((b_n >> (d-1)) & 1)
        #print('d:', d, 'b_n:', f'{b_n:8b} {new_b_n:8b}')
        if new_b_n == 0:
            return x

        rv = sub_add(x_sign, (0x80 | x_n), b_sign, new_b_n, e)

    return rv

積和演算

ここまで来ると積和演算が簡単にできますよ(最終的には使わなかったりするのですが)。

mul_add.py
def mul_add(x:bit16, w:bit16, b:bit16):
    return add(mul(x, w), b)

その次のステップは Python でテストベンチを書いて動かすことです。テストベンチ?!

リンク集