tf.estimator.DNNClassifier构造函数具有
weight_column参数:
weight_column
:字符串或_NumericColumn通过tf.feature_column.numeric_column定义表示权重的特征列而创建的
。在训练过程中,它可用于减轻体重或增强示例效果。它将乘以示例的损失。如果是字符串,则用作从中获取权重张量的键features。如果为_NumericColumn,则通过key获取原始张量weight_column.key,然后weight_column.normalizer_fn将其应用于权重张量。
因此,只需添加一个新列并为稀有类填充一些权重即可:
weight = tf.feature_column.numeric_column('weight')...tf.estimator.DNNClassifier(..., weight_column=weight)
[更新] 这是一个完整的工作示例:
import numpy as npimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('mnist', one_hot=False)train_x, train_y = mnist.train.next_batch(1024)test_x, test_y = mnist.test.images, mnist.test.labelsx_column = tf.feature_column.numeric_column('x', shape=[784])weight_column = tf.feature_column.numeric_column('weight')classifier = tf.estimator.DNNClassifier(feature_columns=[x_column], hidden_units=[100, 100], weight_column=weight_column, n_classes=10)# Trainingtrain_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': train_x, 'weight': np.ones(train_x.shape[0])}, y=train_y.astype(np.int32), num_epochs=None, shuffle=True)classifier.train(input_fn=train_input_fn, steps=1000)# Testingtest_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': test_x, 'weight': np.ones(test_x.shape[0])}, y=test_y.astype(np.int32), num_epochs=1, shuffle=False)acc = classifier.evaluate(input_fn=test_input_fn)print('Test Accuracy: %.3f' % acc['accuracy'])
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)