TensorFlow Object Detection API をしこたま教育してやるわ[TFRecord編]
TensorFlow Object Detection
went-went-takkun135.hatenablog.com
この記事の続きです。 独自のデータセットで学習を行いたいから, このモデルに学習させる方法を調べてこいとのことだったので,
上記のチュートリアルを変形しながら進めていきたいと思います。
Data setの用意
まずはデータセットを用意しなければ始まりません。 チュートリアルの通りデータをダウンロードしてもいいのですが, めちゃくちゃ時間がかかるのが難点です。 自分で用意してもいいんですが, 形式がわからないと思うので下記においておきます。 チュートリアルでダウンロード予定だったannotationsの中にあるxmlのファイルです。
<annotation>
<folder>dataset</folder>
<filename>shiba_inu_11.jpg</filename>
<source>
<annotation>OXIIIT</annotation>
</source>
<size>
<width>500</width>
<height>334</height>
<depth>3</depth>
</size>
<object>
<name>dog</name>
<pose>Frontal</pose>
<truncated>0</truncated>
<occluded>0</occluded>
<bndbox>
<xmin>140</xmin>
<ymin>1</ymin>
<xmax>384</xmax>
<ymax>235</ymax>
</bndbox>
</object>
</annotation>
これがshiba_inu_11.jpgの画像ファイルに1対1対応するxmlファイルになります。
に入れておくといいと思います。 そのほかにもペットのannotationsをダウンロードしてくれば色々入ってくるんですが, とりあえずこの形式に沿っていくつかデータあれば大丈夫だと思います。 要はPythonのlxmlライブラリでこれらの値をdictionaryとして格納して, それをTFRecordとして取り出すのが今回の流れです。
create_pet_tf_record.pyの変更
置換などはvimのコマンドで書いていこうかと思います(すでにAtomに乗り換え済み)
対象変更
とりあえず,今回はpetじゃなくて人間を検出したかったため,
:%s/pet/person/gc
しました。とりあえず。笑
画像形式変更
次に,今回のデータセットで用いる画像形式がjpgではなくpngであったため
:%s/jpg/png/gc
:%s/JPEG/PNG/gc
しました。TensorFlowではjpgかpngが使えるらしいですね。
不要なデータ部分を消去
筆者の用意したデータセットには
- difficult
- pose
- truncated
の3項目が入っていなかったため,既存のコードではエラーを吐かれました。そのため,該当箇所をコメントアウトする必要がありました。(データセットにそれらの情報を加えるのもありだと思いますが,めんどいとも思います)
line 109~145
for obj in data['object']:
difficult = bool(int(obj['difficult']))
if ignore_difficult_instances and difficult:
continue
difficult_obj.append(int(difficult))
xmin.append(float(obj['bndbox']['xmin']) / width)
ymin.append(float(obj['bndbox']['ymin']) / height)
xmax.append(float(obj['bndbox']['xmax']) / width)
ymax.append(float(obj['bndbox']['ymax']) / height)
class_name = get_class_name_from_filename(data['filename'])
classes_text.append(class_name.encode('utf8'))
classes.append(label_map_dict[class_name])
truncated.append(int(obj['truncated']))
poses.append(obj['pose'].encode('utf8'))
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/source_id': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
'image/object/truncated': dataset_util.int64_list_feature(truncated),
'image/object/view': dataset_util.bytes_list_feature(poses),
}))
これを↓に変更
for obj in data['object']:
# difficult = bool(int(obj['difficult']))
# if ignore_difficult_instances and difficult:
# continue
# difficult_obj.append(int(difficult))
xmin.append(float(obj['bndbox']['xmin']) / width)
ymin.append(float(obj['bndbox']['ymin']) / height)
xmax.append(float(obj['bndbox']['xmax']) / width)
ymax.append(float(obj['bndbox']['ymax']) / height)
class_name = get_class_name_from_filename(data['filename'])
classes_text.append(class_name.encode('utf8'))
classes.append(label_map_dict[class_name])
# truncated.append(int(obj['truncated']))
# poses.append(obj['pose'].encode('utf8'))
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/source_id': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_png),
'image/format': dataset_util.bytes_feature('png'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
# 'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
# 'image/object/truncated': dataset_util.int64_list_feature(truncated),
# 'image/object/view': dataset_util.bytes_list_feature(poses),
データセットへのpath指定を細々と変更
xmlまでのpath
path = os.path.join(annotations_dir, 'xmls', example + '.xml')
上記の'xmls',が邪魔ですので消去です。 petのデータセットではannotationsの下にxmlsというディレクトリがあって, その中にxmlファイルが入っているらしいですね。
各データへのpath
line 191~192
examples_path = os.path.join(annotations_dir, 'trainval.txt')
examples_list = dataset_util.read_examples_list(examples_path)
ここを↓に変更
xml_list = os.listdir(annotations_dir)
examples_list = []
for xml_name in xml_list:
temp_splited_name = xml_name.split(".")
examples_list.append(temp_splited_name[0])
また,各データの名前を入れたテキストファイルも用意していなかったので, Pythonに頑張ってもらいました。
実行
多分上記の変更でTFRecord作れるようになっていると思います! 今,このTFRecord使って学習させようと奮闘中なのですが, なーんか時間かかりすぎてて絶対間違ってるなーって状態です笑
次回,奮闘の末に勝利を納めて入れば学習のセットアップまで書きます!