dondakeshimoの丸太

データサイエンス/Webアプリケーション/

TensorFlow Object Detection API をしこたま教育してやるわ[TFRecord編]

TensorFlow Object Detection

went-went-takkun135.hatenablog.com

この記事の続きです。 独自のデータセットで学習を行いたいから, このモデルに学習させる方法を調べてこいとのことだったので,

github.com

上記のチュートリアルを変形しながら進めていきたいと思います。

Data setの用意

まずはデータセットを用意しなければ始まりません。 チュートリアルの通りデータをダウンロードしてもいいのですが, めちゃくちゃ時間がかかるのが難点です。 自分で用意してもいいんですが, 形式がわからないと思うので下記においておきます。 チュートリアルでダウンロード予定だったannotationsの中にあるxmlのファイルです。

<annotation>
  <folder>dataset</folder>
  <filename>shiba_inu_11.jpg</filename>
  <source>
    <annotation>OXIIIT</annotation>
  </source>
  <size>
    <width>500</width>
    <height>334</height>
    <depth>3</depth>
  </size>
  <object>
    <name>dog</name>
    <pose>Frontal</pose>
    <truncated>0</truncated>
    <occluded>0</occluded>
    <bndbox>
      <xmin>140</xmin>
      <ymin>1</ymin>
      <xmax>384</xmax>
      <ymax>235</ymax>
    </bndbox>
  </object>
</annotation>

これがshiba_inu_11.jpgの画像ファイルに1対1対応するxmlファイルになります。

  • 画像ファイルをimages/
  • xmlファイルをannotations/
  • 二つのディレクトリをdataset/

に入れておくといいと思います。 そのほかにもペットのannotationsをダウンロードしてくれば色々入ってくるんですが, とりあえずこの形式に沿っていくつかデータあれば大丈夫だと思います。 要はPythonのlxmlライブラリでこれらの値をdictionaryとして格納して, それをTFRecordとして取り出すのが今回の流れです。

create_pet_tf_record.pyの変更

置換などはvimのコマンドで書いていこうかと思います(すでにAtomに乗り換え済み)

対象変更

とりあえず,今回はpetじゃなくて人間を検出したかったため,

:%s/pet/person/gc

しました。とりあえず。笑

画像形式変更

次に,今回のデータセットで用いる画像形式がjpgではなくpngであったため

:%s/jpg/png/gc
:%s/JPEG/PNG/gc

しました。TensorFlowではjpgかpngが使えるらしいですね。

不要なデータ部分を消去

筆者の用意したデータセットには

  • difficult
  • pose
  • truncated

の3項目が入っていなかったため,既存のコードではエラーを吐かれました。そのため,該当箇所をコメントアウトする必要がありました。(データセットにそれらの情報を加えるのもありだと思いますが,めんどいとも思います)

line 109~145

for obj in data['object']:
    difficult = bool(int(obj['difficult']))
  if ignore_difficult_instances and difficult:
    continue

  difficult_obj.append(int(difficult))

  xmin.append(float(obj['bndbox']['xmin']) / width)
  ymin.append(float(obj['bndbox']['ymin']) / height)
  xmax.append(float(obj['bndbox']['xmax']) / width)
  ymax.append(float(obj['bndbox']['ymax']) / height)
  class_name = get_class_name_from_filename(data['filename'])
  classes_text.append(class_name.encode('utf8'))
  classes.append(label_map_dict[class_name])
  truncated.append(int(obj['truncated']))
  poses.append(obj['pose'].encode('utf8'))

example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': dataset_util.int64_feature(height),
    'image/width': dataset_util.int64_feature(width),
    'image/filename': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/source_id': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
    'image/encoded': dataset_util.bytes_feature(encoded_jpg),
    'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
    'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
    'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
    'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
    'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
    'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
    'image/object/class/label': dataset_util.int64_list_feature(classes),
    'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
    'image/object/truncated': dataset_util.int64_list_feature(truncated),
    'image/object/view': dataset_util.bytes_list_feature(poses),
}))

これを↓に変更

for obj in data['object']:
  # difficult = bool(int(obj['difficult']))
  # if ignore_difficult_instances and difficult:
  #   continue

  # difficult_obj.append(int(difficult))

  xmin.append(float(obj['bndbox']['xmin']) / width)
  ymin.append(float(obj['bndbox']['ymin']) / height)
  xmax.append(float(obj['bndbox']['xmax']) / width)
  ymax.append(float(obj['bndbox']['ymax']) / height)
  class_name = get_class_name_from_filename(data['filename'])
  classes_text.append(class_name.encode('utf8'))
  classes.append(label_map_dict[class_name])
  # truncated.append(int(obj['truncated']))
  # poses.append(obj['pose'].encode('utf8'))

example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': dataset_util.int64_feature(height),
    'image/width': dataset_util.int64_feature(width),
    'image/filename': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/source_id': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
    'image/encoded': dataset_util.bytes_feature(encoded_png),
    'image/format': dataset_util.bytes_feature('png'.encode('utf8')),
    'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
    'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
    'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
    'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
    'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
    'image/object/class/label': dataset_util.int64_list_feature(classes),
  #   'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
  #   'image/object/truncated': dataset_util.int64_list_feature(truncated),
  #   'image/object/view': dataset_util.bytes_list_feature(poses),

データセットへのpath指定を細々と変更

xmlまでのpath
path = os.path.join(annotations_dir, 'xmls', example + '.xml')

上記の'xmls',が邪魔ですので消去です。 petのデータセットではannotationsの下にxmlsというディレクトリがあって, その中にxmlファイルが入っているらしいですね。

各データへのpath

line 191~192

examples_path = os.path.join(annotations_dir, 'trainval.txt')
examples_list = dataset_util.read_examples_list(examples_path)

ここを↓に変更

xml_list = os.listdir(annotations_dir)
examples_list = []
for xml_name in xml_list:
  temp_splited_name = xml_name.split(".")
  examples_list.append(temp_splited_name[0])

また,各データの名前を入れたテキストファイルも用意していなかったので, Pythonに頑張ってもらいました。

実行

多分上記の変更でTFRecord作れるようになっていると思います! 今,このTFRecord使って学習させようと奮闘中なのですが, なーんか時間かかりすぎてて絶対間違ってるなーって状態です笑

次回,奮闘の末に勝利を納めて入れば学習のセットアップまで書きます!