FC2カウンター FPGAの部屋 MNISTデータセットの一部をC の配列に変換するPython コードの更新

FPGAやCPLDの話題やFPGA用のツールの話題などです。 マニアックです。 日記も書きます。

FPGAの部屋

FPGAの部屋の有用と思われるコンテンツのまとめサイトを作りました。Xilinx ISEの初心者の方には、FPGAリテラシーおよびチュートリアルのページをお勧めいたします。

MNISTデータセットの一部をC の配列に変換するPython コードの更新

「ゼロから作るDeep Learning」の2層ニューラルネットワークのハードウェア化3”に貼った、「Vivado HLSのテストベンチに必要なMNISTデータセットの一部をC の配列に変換するPython コード」を更新した。
10000個のMNISTのテストデータの任意の位置の100個を抽出できるように変更したので、貼っておく。
(2017/08/27 : バグがあったので修正)

# MNISTのデータをCの配列に出力し、ファイルに書き込み

# coding: utf-8
import sys, os
sys.path.append(os.pardir)

import numpy as np
from dataset.mnist import load_mnist
import datetime

OUTPUT_DATA_NUM = 100 # 出力するMNISTのテストデータ数 10000までの数
OFFSET = 100 # MNISTデータセットのオフセット、100だったら100番目からOUTPUT_DATA_NUM個を出力する

# データの読み込み
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

f = open("mnist_data.h", 'w')
todaytime = datetime.datetime.today()
f.write('// mnist_data.h\n')
strdtime = todaytime.strftime("%Y/%m/%d %H:%M:%S")
f.write('// {0} by marsee\n'.format(strdtime))
f.write("\n")

f.write('ap_ufixed<8, 0, AP_TRN_ZERO, AP_SAT> t_train['+str(OUTPUT_DATA_NUM)+']['+str(x_test.shape[1])+'] = {\n')
for i in range(OFFSET, OFFSET+OUTPUT_DATA_NUM):
    f.write("\t{")
    for j in range(x_test.shape[1]):
        f.write(str(x_test[i][j]))
        if (j==x_test.shape[1]-1):
            if (i==OUTPUT_DATA_NUM-1):
                f.write("}\n")
            else:
                f.write("},\n")
        else:
            f.write(", ")
f.write("};\n")

f.write('int t_train_256['+str(OUTPUT_DATA_NUM)+']['+str(x_test.shape[1])+'] = {\n')
for i in range(OFFSET, OFFSET+OUTPUT_DATA_NUM):
    f.write("\t{")
    for j in range(x_test.shape[1]):
        f.write(str(int(x_test[i][j]*256)))
        if (j==x_test.shape[1]-1):
            if (i==OUTPUT_DATA_NUM-1):
                f.write("}\n")
            else:
                f.write("},\n")
        else:
            f.write(", ")
f.write("};\n")

f.write("\n")
f.write('float t_test['+str(OUTPUT_DATA_NUM)+']['+str(t_test.shape[1])+'] = {\n')
for i in range(OFFSET, OFFSET+OUTPUT_DATA_NUM):
    f.write("\t{")
    for j in range(t_test.shape[1]):
        f.write(str(t_test[i][j]))
        if (j==t_test.shape[1]-1):
            if (i==OUTPUT_DATA_NUM-1):
                f.write("}\n")
            else:
                f.write("},\n")
        else:
            f.write(", ")
f.write("};\n")
f.close() 

  1. 2017年06月06日 04:45 |
  2. DNN
  3. | トラックバック:0
  4. | コメント:0

コメント

コメントの投稿


管理者にだけ表示を許可する

トラックバック URL
http://marsee101.blog19.fc2.com/tb.php/3822-8cbb79c5
この記事にトラックバックする(FC2ブログユーザー)