Quantcast
Channel: Google Developers Japan
Viewing all articles
Browse latest Browse all 2210

TensorFlow の Dataset と Estimator の紹介

$
0
0
この記事は TensorFlow チームによる Google Developers Blog の記事 "Introduction to TensorFlow Datasets and Estimators" を元に翻訳・加筆したものです。詳しくは元記事をご覧ください。

TensorFlow 1.3 では次の 2 つの重要な機能が導入されました。
  • Dataset: 入力パイプライン(プログラムにデータを読み込む部分)を作成する新しい API です
  • Estimator: TensorFlow モデルを作成する高レベル API です。一般的な機械学習タスク用に事前に作成されたモデルを提供します。独自のカスタムモデルも作成できます

以下の図は、TensorFlow アーキテクチャにおけるこれらの機能の位置付けを示します。この 2 つを組み合わせると、簡単な作業で TensorFlow モデルを作成して学習を行えます。




サンプルモデル


Dataset と Estimator の使い方の例として、簡単なモデルを作成するためのサンプルコードを紹介します。学習とテスト用のファイルの取得方法を含むコード全体はこちらで入手できます(なお、このコードは Dataset と Estimator がどのように機能するか示すためのもので、性能の最適化は考慮されていません)。

このモデルの学習を行うと、アヤメの花を 4 つの特徴(がく片の長さ、がく片の幅、花弁の長さ、花弁の幅)に基づいて分類するモデルが得られます。4 つの特徴の値をモデルに与えると、以下の 3 種類の変種のうちどれかを推論します。


左から順番にヒオウギアヤメRadomilより、CC BY-SA 3.0)、ブルーフラッグDlangloisより、CC BY-SA 3.0)、ヴァージニアアイリスFrank Mayfieldより、CC BY-SA 2.0)

ここでは、以下の構造のモデルを使い、ディープ ニューラル ネットワークによる分類器の学習を行います。すべての入力値と出力値は float32 となり、出力値の合計は 1 になります(個々のアヤメの種類についてその確率を予測するため)。


たとえば、ヒオウギアヤメの出力結果が 0.05、ブルーフラッグが 0.9、ヴァージニアアイリスが 0.05 という結果が得られた場合、アヤメがブルーフラッグである確率が 90% という意味になります。

これでモデルの定義が完了です。次に、Dataset と Estimator を使ってこのモデルによる学習と推論を行う方法を説明します。

Dataset の概要


Dataset は、TensorFlow モデルへの入力パイプラインを作成する新しい方法です。この API は feed_dict や Queue によるパイプラインより高性能で、より分かりやすく簡単に使用できます。TensorFlow 1.3 では Dataset はまだ tf.contrib.data にあり、コア API ではありませんが、1.4 ではコア API に昇格する予定です。Dataset を試し始めるには今がよい時機です。

Dataset は次のクラスで構成されます。


各クラスの説明は次のとおりです。
  • Dataset: Dataset を作成して変換するメソッドを含む基底クラスです。メモリ内のデータ、または Python ジェネレータから Dataset を初期化する機能を提供します
  • TextLineDataset: テキスト ファイルから行を読み取ります
  • TFRecordDataset: TFRecord ファイルからデータを読み取ります
  • FixedLengthRecordDataset: バイナリ ファイルから固定サイズのデータを読み取ります
  • Iterator: Dataset の各要素にひとつずつアクセスするために使います

サンプルコードのデータセット


まずはモデルの学習に使用するデータセットを見てみましょう。以下のような CSV ファイルを用います。各行には 5 つの値(4 つの入力値とラベル)が含まれます。

ラベルの定義は以下の通りです。
  • ヒオウギアヤメは 0
  • ブルーフラッグは 1
  • ヴァージニアアイリスは 2

Dataset を定義する


Dataset を定義するには、まず特徴量のリストを作成します。
feature_names = [
'SepalLength',
'SepalWidth',
'PetalLength',
'PetalWidth']

一方、モデルの学習を行うときは、入力ファイルを読み取って特徴量とラベルを返す入力関数が必要になります。次の形式の関数を作成し、Estimator に渡します。
def input_fn():
...<code>...
return ({ 'SepalLength':[values], ..<etc>.., 'PetalWidth':[values] },
[IrisFlowerType])

戻り値は、次の 2 要素のタプルである必要があります。
  • 最初の要素には辞書(dict)を返します。各入力特徴量のキーと、学習バッチの値のリストのペアです
  • 2 番目の要素は、学習バッチのラベルのリストとします

Estimator は、この return 文で返された値のリストとラベルのリストを用いて学習を行います。よって、これらはいずれも同じ長さとします。ちなみに、ここでいう「リスト」とは、1 次元の TensorFlow テンソルを指します。

この input_fn 関数を後から簡単に再利用できるように、いくつかの引数を追加します。これにより、設定を変えるだけで様々な入力関数を作成できます。これらの引数の意味は簡単です。
  • file_path: 読み取るデータファイルです
  • perform_shuffle: データをシャッフルするかどうかを指定します
  • repeat_count: Dataset の行の読み取りを繰り返す回数を指定します。たとえば 1 を指定すると、各行を 1 回だけ読みとる Dataset となります。None を指定すると、各行を繰り返し読み取り可能な Dataset となります

この入力関数を Dataset API で実装する方法を以下に示します。なお、この後で Estimator に入力関数を渡すとき、以下の関数をそのまま使わず、ラップした関数を使います。

def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
def decode_csv(line):
parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
label = parsed_line[-1:] # Last element is the label
del parsed_line[-1] # Delete last element
features = parsed_line # Everything (but last element) are the features
d = dict(zip(feature_names, features)), label
return d

dataset = (tf.contrib.data.TextLineDataset(file_path) # Read text file
.skip(1) # Skip header row
.map(decode_csv)) # Transform each elem by applying decode_csv fn
if perform_shuffle:
# Randomizes input using a window of 256 elements (read into memory)
dataset = dataset.shuffle(buffer_size=256)
dataset = dataset.repeat(repeat_count) # Repeats dataset this # times
dataset = dataset.batch(32) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels

ここでの要点は以下の通りです。
  • TextLineDataset: この Dataset API は、ファイルベースのデータセットを使用する際に必要となる様々なメモリ管理を行います。たとえば、メモリよりも大幅に大きいデータセットを読み込んだり、リストを引数として指定した複数のファイルを読み込んだりできます
  • shuffle: buffer_size で指定した数の要素を読み取り、その順番をシャッフルします
  • map: データセットの個々の要素を引数として decode_csv関数を呼び出します(この例では TextLineDataset を使用するため、つまり CSV テキストの各行を decode_csv に渡すことを意味します)。
  • decode_csv: 各行をフィールドに分割し、必要に応じてデフォルト値を指定します。つづいて、フィールドのキーと値の辞書を返します。上述の map 関数は、この辞書を使って各要素(行)を書き換えます

これで Dataset API の概要の説明は終わりです。この関数を使用して、学習バッチの冒頭部分を試してみましょう。

next_batch = my_input_fn(FILE, True) # Will return 32 random elements

# Now let's try it out, retrieving and printing one batch of data.
# Although this code looks strange, you don't need to understand
# the details.
with tf.Session() as sess:
first_batch = sess.run(next_batch)
print(first_batch)

# Output
({'SepalLength': array([ 5.4000001, ...<repeat to 32 elems>], dtype=float32),
'PetalWidth': array([ 0.40000001, ...<repeat to 32 elems>], dtype=float32),
...
},
[array([[2], ...<repeat to 32 elems>], dtype=int32) # Labels
)

このサンプルコードにおける Dataset の機能は以上です。Dataset API にはさらに多くの機能があります。詳細については、この記事の最後にあるリンク先をご覧ください。

Estimator の概要


Estimator は、TensorFlow モデルの学習時にこれまで必要だった手間のかかるコーディングが不要となる高レベル API です。柔軟性も高く、個々のモデルの要件に合わせてデフォルトの動作をオーバーライドできます。

Estimator は以下の 2 つの方法で作成できます。
  • 定義済み Estimator を使う - 特定の種類のモデルを生成するために事前に定義された Estimator です。この記事では、そのひとつである DNNClassifier を使う例を紹介します
  • Estimator クラス(基底クラス)を使う - model_fn 関数を使用して、モデルの作成方法をカスタマイズできる方法です。この方法については、また別の機会に紹介します
Estimator クラスの概要を以下に示します。


TensorFlow の今後のリリースでは、さらに多くの種類の定義済み Estimator を提供する予定です。

上図に示すように、いずれの Estimator も input_fn 関数を使用してデータを受け取ります。

以下のコードは、アヤメの種類を予測する Estimator を作成する例です。

# Create the feature_columns, which specifies the input to our model.
# All our input features are numeric, so use numeric_column for each one.
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

# Create a deep neural network regression classifier.
# Use the DNNClassifier pre-made estimator
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns, # The input features to our model
hidden_units=[10, 10], # Two layers, each with 10 neurons
n_classes=3,
model_dir=PATH) # Path to where checkpoints etc are stored

では、ここで作成した Estimator の学習を行いましょう。

モデルのトレーニング


以下のように一行書くだけで学習を行えます。

# Train our model, use the previously function my_input_fn
# Input to training is a file with training example
# Stop training after 8 iterations of train data (epochs)
classifier.train(
input_fn=lambda: my_input_fn(FILE_TRAIN, True, 8))

この「lambda: my_input_fn(FILE_TRAIN, True, 8)」という部分で、先程作成した Dataset とこの Estimator を結びつけています。Estimator は、モデルの学習と評価、推論に必要な入力データを input_fn 関数を介して取得します。この input_fn 関数としては、引数なしの関数を渡す必要があります。そこで、my_input_fn 関数の引数 file_path、shuffle setting、repeat_count として以下の値を指定する input_fn 関数を lambda で新たに作成し、Estimator に渡しています。
  • FILE_TRAIN: 学習用のデータファイルのパスです
  • True: データのシャッフルを指定します
  • 8: データセットを 8 回繰り返すよう指定します

学習したモデルの評価


これで、モデルの学習が終わりました。Estimator の evaluate メソッドを呼び出すと、モデルの精度を評価できます。

# Evaluate our model using the examples contained in FILE_TEST
# Return value will contain evaluation_metrics such as: loss & average_loss
evaluate_result = estimator.evaluate(
input_fn=lambda: my_input_fn(FILE_TEST, False, 4)
print("Evaluation results")
for key in evaluate_result:
print(" {}, was: {}".format(key, evaluate_result[key]))

今回のサンプルコードの例では、最大で約 93% の精度が得られます。この精度を高めるにはいろいろな方法が考えられますが、そのひとつは、単にこのコードを繰り返し実行することです。モデルの状態は上記で指定した model_dir=PATH に保存されるので、繰り返し学習を行うことで、モデルの精度は収束しながら向上していきます。

もう 1 つの方法は、モデルの隠れ層の数や各層のノード数を調整することです。これらの方法を自由に試せますが、変更すると DNNClassifierのモデルの構造が変化するため model_dir=PATHで指定したディレクトリを削除する必要があることに注意してください。

トレーニングされたモデルを使用した予測


ここまでの作業で、モデルの学習を行い、その精度も十分なものであることを評価しました。さっそくこのモデルを使い、アヤメの分類を試しましょう。学習や評価の場合と同様に、predict メソッドを呼び出すだけで予測が可能です。

# Predict the type of some Iris flowers.
# Let's predict the examples in FILE_TEST, repeat only once.
predict_results = classifier.predict(
input_fn=lambda: my_input_fn(FILE_TEST, False, 1))
print("Predictions on test file")
for prediction in predict_results:
# Will print the predicted class, i.e: 0, 1, or 2 if the prediction
# is Iris Sentosa, Vericolor, Virginica, respectively.
print prediction["class_ids"][0]

オンメモリのデータで予測

上記のコードでは、FILE_TESTを指定してファイル保存されたデータに対して予測を行いました。一方、オンメモリのデータなど、他の場所にあるデータに対して予測を行うにはどうすればよいでしょうか。この場合も、predict メソッドの呼び出しは変更する必要がありません。以下のように、オンメモリのデータを参照するように Dataset API を設定します。
# Let create a memory dataset for prediction.
# We've taken the first 3 examples in FILE_TEST.
prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor
[6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosa
def new_input_fn():
def decode(x):
x = tf.split(x, 4) # Need to split into our 4 features
# When predicting, we don't need (or have) any labels
return dict(zip(feature_names, x)) # Then build a dict from them

# The from_tensor_slices function will use a memory structure as input
dataset = tf.contrib.data.Dataset.from_tensor_slices(prediction_input)
dataset = dataset.map(decode)
iterator = dataset.make_one_shot_iterator()
next_feature_batch = iterator.get_next()
return next_feature_batch, None # In prediction, we have no labels

# Predict all our prediction_input
predict_results = classifier.predict(input_fn=new_input_fn)

# Print results
print("Predictions on memory data")
for idx, prediction in enumerate(predict_results):
type = prediction["class_ids"][0] # Get the predicted class (index)
if type == 0:
print("I think: {}, is Iris Sentosa".format(prediction_input[idx]))
elif type == 1:
print("I think: {}, is Iris Versicolor".format(prediction_input[idx]))
else:
print("I think: {}, is Iris Virginica".format(prediction_input[idx])

Dataset.from_tensor_slides()は、オンメモリで保存できる小規模なデータセット用に設計されています。一方で、モデルの学習と評価の場合と同様に TextLineDatasetを使用すると、シャッフル用のバッファと学習バッチがメモリサイズに収まる限り、任意の大きなファイルを使って予測を行えます。

TensorBoard で可視化


ここまで見てきたとおり、DNNClassifier などの定義済み Estimator はその使いやすさなど大きなメリットを提供します。加えて、各種の評価値の計測機能も備えており、TensorBoard によるサマリー表示が可能です。このサマリーレポートを表示するには、コマンドラインから次のコードを実行して TensorBoard を起動します。

# Replace PATH with the actual path passed as model_dir argument when the
# DNNRegressor estimator was created.
tensorboard --logdir=PATH

TensorBoard には次のような評価値が表示されます。


まとめ


この記事では、TensorFlow の新しい API である Dataset と Estimator について説明しました。これらは入力データのストリームを定義してモデルを作成できる重要な API であり、その習得に必要なコストを大きく上回るメリットを提供します。

詳細については、次の各ページをご覧ください。

これで終わりではありません。これらの API がどのように動作するかを説明した投稿を間もなく公開しますので、ご期待ください。

それまで、TensorFlow のコーディングをお楽しみください。



Reviewed by Kaz Sato - Staff Developer Advocate, Google Cloud

Viewing all articles
Browse latest Browse all 2210

Latest Images

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

Trending Articles



Latest Images

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭

赤坂中華 わんたん亭