Цветовая карта на основе столбцов с matplotlib.pyplot.imshow

0

Я хочу преобразовать свои входные и выходные векторы в массив и построить график с цветовой картой, которая отлично работает. См. Пример ниже:

x = np.round(np.array([-1, 0,     1,  2, 3, 2, 0, 0, 1, 2, 3]))
y = np.round(np.array([-1, 1.2, -3.1, 2, 3, 2, 1, 1, 1, 3, 3]))

xrange = np.arange(np.min(x),np.max(x) + 1)
yrange = np.arange(np.min(y),np.max(y) + 1)
a = np.zeros((len(yrange), len(xrange)))
for i in range(len(x)):
    a[yrange == y[i],xrange == x[i]] = a[yrange == y[i],xrange == x[i]] + 1

fig, ax = plt.subplots()
im = ax.imshow(a,cmap='Wistia',origin="lower")
for i in range(len(yrange)):
    for j in range(len(xrange)):
        text = ax.text(j, i, str(int(a[i, j])),
                       ha="center", va="center", color="k")
ax.set(xlabel = 'Input', ylabel = 'Output')
fig.tight_layout()
plt.show()

Что дает мне это:

введите описание изображения здесь

Единственная проблема в том, что мне нравится использовать палитру на основе столбцов. Это означает, что, например, в первом столбце я хочу, чтобы 1 считался наивысшим числом и имел тот же цвет, что и 3 во втором столбце, и это касается всех столбцов.