Python - TensorFlow: update the first matching element in each row
Building on a previous question, I'm looking to update the values of a 2-D tensor the first time in each row that a tf.where condition is met. Here is the sample code I'm using to simulate it:
```python
import tensorflow as tf  # TensorFlow 1.x API

tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    val = "hello"
    new_val = "goodbye"
    matrix = tf.constant([["word", "hello", "hello"],
                          ["word", "other", "hello"],
                          ["hello", "hello", "hello"],
                          ["word", "word", "word"]])
    # (row, column) coordinates of every element equal to val
    matching_indices = tf.where(tf.equal(matrix, val))
    # smallest column index within each row, grouped by row index
    first_matching_idx = tf.segment_min(data=matching_indices[:, 1],
                                        segment_ids=matching_indices[:, 0])

sess = tf.InteractiveSession(graph=graph)
print(sess.run(first_matching_idx))
```
This outputs [1, 2, 0]: 1 is the position of the first "hello" in row 1, 2 is the position of the first "hello" in row 2, and 0 is the position of the first "hello" in row 3.
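To make the meaning of those indices concrete, here is a NumPy sketch of the same "first match per row" lookup. It is not what tf.segment_min does internally; it uses a boolean mask plus argmax instead. Note that the TensorFlow result has only three entries because the fourth row (all "word") never appears in matching_indices[:, 0]. The -1 sentinel for rows without a match is an assumption made here for illustration.

```python
import numpy as np

# Same 4x3 matrix as in the question
matrix = np.array([["word", "hello", "hello"],
                   ["word", "other", "hello"],
                   ["hello", "hello", "hello"],
                   ["word", "word", "word"]])
val = "hello"

mask = matrix == val            # boolean (4, 3) array of matches
has_match = mask.any(axis=1)    # which rows contain val at all
# argmax returns the index of the first True in each row (0 if there is none),
# so rows without any match are marked with the sentinel -1 instead
first_idx = np.where(has_match, mask.argmax(axis=1), -1)
print(first_idx)  # [ 1  2  0 -1]
```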
However, I can't figure out a way to get the first matching element updated with the new value; I want the first "hello" in each row turned into "goodbye". I have tried using tf.scatter_update(), but it does not seem to work on 2-D tensors. Is there a way to modify a 2-D tensor as described?
One easy workaround is to use tf.py_func and do the update on a NumPy array:
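```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

def ch_val(array, val, new_val):
    # work on a copy so the input buffer handed over by py_func is never modified
    array = array.copy()
    # (row, column) of the first occurrence of val in every row that contains it
    idx = np.array([[s, list(row).index(val)]
                    for s, row in enumerate(array) if val in row])
    if idx.size == 0:  # no row contains val, nothing to replace
        return array
    idx = tuple((idx[:, 0], idx[:, 1]))
    array[idx] = new_val
    return array

...

matrix = tf.Variable([["word", "hello", "hello"],
                      ["word", "other", "hello"],
                      ["hello", "hello", "hello"],
                      ["word", "word", "word"]])
matrix = tf.py_func(ch_val, [matrix, 'hello', 'goodbye'], tf.string)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(matrix))
# result (under Python 3 each element prints as a bytes literal, e.g. b'goodbye'):
# [['word' 'goodbye' 'hello']
#  ['word' 'other' 'goodbye']
#  ['goodbye' 'hello' 'hello']
#  ['word' 'word' 'word']]
```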
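The per-row Python loop in ch_val can also be replaced by a vectorized mask. The sketch below uses a hypothetical helper, replace_first_match, that keeps only the first True of each row by checking where the running count of matches (np.cumsum along axis 1) first equals 1. It should be usable with tf.py_func in the same way as ch_val, but that is an untested assumption; it is only exercised here on plain Python strings. Converting to an object array matters: with a fixed-width '<U5' string array, assigning "goodbye" would silently truncate it to "goodb".

```python
import numpy as np

def replace_first_match(array, val, new_val):
    """Hypothetical vectorized alternative to ch_val: replace only the first
    occurrence of val in each row and return a new array."""
    # copy into an object array so longer replacement strings are not truncated
    out = np.array(array, dtype=object)
    mask = out == val  # True wherever the element equals val
    # running count of matches along each row; the first match is where it reaches 1
    first = mask & (np.cumsum(mask, axis=1) == 1)
    out[first] = new_val
    return out

matrix = [["word", "hello", "hello"],
          ["word", "other", "hello"],
          ["hello", "hello", "hello"],
          ["word", "word", "word"]]
result = replace_first_match(matrix, "hello", "goodbye")
print(result.tolist())
```

The same cumsum-and-mask idea could in principle be written in-graph with tf.cumsum and tf.where to avoid py_func altogether, though that variant is not shown or tested here.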