How to get the index of the largest item in each row with NumPy?


I am trying to get a list of the indices of the largest element in each row:

import numpy as np

a = np.arange(16).reshape((4, 4))
print(a)
print(np.amax(a, axis=1))  # the maximum of each row
print(np.where(a == np.amax(a, axis=1)))

This is the result: instead of one index per row, I only get the position of the largest element of the whole matrix.

In this case, the list I want to obtain is:

[3, 3, 3, 3]

i.e. for each row, the column in which its largest element is located.
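The attempt above fails because of broadcasting: `np.amax(a, axis=1)` has shape (4,), and NumPy aligns it with the last axis of `a`, so `a[i, j]` is compared with the maximum of row `j` rather than row `i`. A minimal sketch of one way to repair that approach, using the standard `keepdims` parameter of `np.amax` so each row is compared with its own maximum:

```python
import numpy as np

a = np.arange(16).reshape((4, 4))

# Shape (4,): broadcasting aligns it with the LAST axis of a, so
# a[i, j] is compared with the maximum of row j, not row i.
row_max = np.amax(a, axis=1)
bad_rows, bad_cols = np.where(a == row_max)
print(bad_rows.tolist(), bad_cols.tolist())  # [3] [3]: only a[3, 3] happens to match

# Shape (4, 1): keepdims=True keeps the reduced axis as length 1, so
# each row is compared with its own maximum.
row_max_col = np.amax(a, axis=1, keepdims=True)
rows, cols = np.where(a == row_max_col)
print(cols.tolist())  # [3, 3, 3, 3]
```

Unlike `argmax`, this returns every column where a row's maximum occurs, so a row with tied maxima contributes more than one index.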

    
asked by Luis Alberto Acosta on 20.11.2017 at 05:33

1 answer


NumPy provides the argmax function, which returns the index of the largest element along a row, a column, or the entire (flattened) array; if the maximum occurs more than once, the index of its first occurrence is returned. Since you want the largest element in each row, use axis=1, for example:

import numpy as np

a = np.random.randint(20, size=16).reshape((4, 4))
print(a)

print(np.argmax(a, axis=1))

Output (the array is random, so your values will differ):

[[10 15  2  3]
 [ 9  0 10 19]
 [ 3 18 17 18]
 [ 5 15  4  8]]
[1 3 1 1]
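Applied to the deterministic array from the question, the same call gives exactly the list asked for. The sketch below also checks the tie behaviour seen in the third row of the output above, and, as an optional extra, uses `np.take_along_axis` to fetch the maxima from the returned indices:

```python
import numpy as np

# The array from the question: each row's maximum is in its last column.
a = np.arange(16).reshape((4, 4))
idx = np.argmax(a, axis=1)
print(idx.tolist())  # [3, 3, 3, 3]

# With ties, argmax reports the first occurrence (like row [3 18 17 18] above).
b = np.array([[3, 18, 17, 18]])
print(np.argmax(b, axis=1).tolist())  # [1]

# idx[:, None] has shape (4, 1), matching the (4, 4) array along axis 0,
# so take_along_axis picks one element per row; ravel() flattens the result.
maxima = np.take_along_axis(a, idx[:, None], axis=1).ravel()
print(maxima.tolist())  # [3, 7, 11, 15]
```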
    
answered by 20.11.2017 / 05:46