3D plotting in Python - Adding a Legend to Scatterplot

November 30, 2019, at 8:30 PM
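import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Assumed setup (omitted in the original snippet); X_lda and y_train are defined elsewhere.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X_lda[:,0], X_lda[:,1], X_lda[:,2], alpha=0.4, c=y_train, cmap='rainbow', s=20)
plt.legend()  # no labeled artists, so the legend comes out empty
plt.show()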

Essentially, I'd like to add a legend to the scatterplot that shows the unique values in y_train and the color each one corresponds to on the plot.

[Figure: the output plot]

Answer 1

Producing either a legend or a colorbar for a scatter is usually quite simple:

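import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the 3D projection on older Matplotlib

# 400 random points per coordinate; the offset of each point matches its class in c
x, y, z = (np.random.normal(size=(300, 4)) + np.array([0, 2, 4, 6])).reshape(3, 400)
c = np.tile([1, 2, 3, 4], 100)

fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
sc = ax.scatter(x, y, z, alpha=0.4, c=c, cmap='rainbow', s=20)

# Legend: one entry per unique value in c (legend_elements needs Matplotlib >= 3.1)
plt.legend(*sc.legend_elements())
# Colorbar: continuous color scale for the same data
plt.colorbar(sc)
plt.show()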
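
If the legend should show readable class names rather than the raw values, the handles returned by legend_elements() can be paired with your own labels. The following is only a minimal sketch with simulated data; the names in class_names and the legend title are hypothetical placeholders, and the Agg backend is selected only so the script runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed only on older Matplotlib)

rng = np.random.default_rng(0)
c = np.tile([1, 2, 3, 4], 100)                 # four classes, as in the answer above
xyz = rng.normal(size=(3, 400)) + 2 * (c - 1)  # shift each class along the diagonal

fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
sc = ax.scatter(*xyz, c=c, cmap="rainbow", alpha=0.4, s=20)

# legend_elements() returns (handles, labels), one pair per unique value in c
handles, labels = sc.legend_elements()

# Hypothetical class names, listed in the order of the sorted unique values of c
class_names = ["class A", "class B", "class C", "class D"]
legend = ax.legend(handles, class_names, title="y_train", loc="upper left")

fig.canvas.draw()  # render once; use plt.show() or fig.savefig(...) in real use
```

The handles come back in the order of the sorted unique values, so class_names has to follow that same order.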

Answer 2

Edit: After seeing @bigreddot's solution, I agree that this approach is somewhat more complicated than strictly necessary. I leave it here in case somebody needs more fine-tuning for their colorbar or legend.

Here is a way to create both a custom legend and a custom colorbar for the 3D graph, so you can choose one or the other depending on your needs. I'm not sure how y_train is distributed, so the code simulates float values drawn from a limited set. It is also not clear what the labels should say, so for now they just show the value of y_train.

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib as mpl

# Simulate data: N points in 3D, with classes 0 and 1 drawn from different distributions
N = 1000
X_lda = np.random.gamma(9.0, 0.5, (N, 3))
y_train = np.random.randint(0, 6, N)
X0 = np.random.gamma(5.0, 1.5, (N, 3))
X1 = np.random.gamma(1.0, 1.5, (N, 3))
for i in range(3):
    X_lda[:, i] = np.where(y_train == 0, X0[:, i], X_lda[:, i])
    X_lda[:, i] = np.where(y_train == 1, X1[:, i], X_lda[:, i])
# Map the six integer classes to six float values to mimic non-integer labels
y_train = np.sin(y_train * .2 + 10) * 10.0 + 20.0

fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X_lda[:, 0], X_lda[:, 1], X_lda[:, 2], alpha=0.4, c=y_train, cmap='rainbow', s=20)

# Same normalization and colormap as the scatter, so legend colors match the points
norm = mpl.colors.Normalize(np.min(y_train), np.max(y_train))
cmap = plt.get_cmap('rainbow')
y_unique = np.unique(y_train)

# Custom legend: one marker-only proxy handle per unique value
legend_lines = [mpl.lines.Line2D([0], [0], linestyle="none", marker='o', c=cmap(norm(y))) for y in y_unique]
legend_labels = [f'{y:.2f}' for y in y_unique]
ax.legend(legend_lines, legend_labels, numpoints=1, title='Y-train')

# Custom colorbar: a standalone ScalarMappable with ticks at the unique values
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # required by older Matplotlib versions; harmless otherwise
# sm is not attached to any Axes, so pass ax explicitly (recent Matplotlib requires it)
plt.colorbar(sm, ax=ax, ticks=y_unique, label='Y-train')
plt.show()
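
When y_train holds only a handful of discrete values, another option is to call scatter once per unique value and pass label=, so that ax.legend() picks up the entries by itself. This is a rough sketch with simulated data, not the approach used above; X and y are hypothetical stand-ins for X_lda and y_train, and the Agg backend is selected only so the script runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed only on older Matplotlib)

rng = np.random.default_rng(0)
X = rng.gamma(9.0, 0.5, (300, 3))  # hypothetical stand-in for X_lda
y = np.tile([0, 1, 2], 100)        # hypothetical stand-in for y_train, three classes

classes = np.unique(y)
# One evenly spaced color from the colormap per class
colors = plt.get_cmap("rainbow")(np.linspace(0, 1, len(classes)))

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
for value, color in zip(classes, colors):
    mask = y == value  # select the points belonging to this class
    ax.scatter(X[mask, 0], X[mask, 1], X[mask, 2],
               color=color, alpha=0.4, s=20, label=str(value))

# Each scatter call carries a label, so legend() needs no explicit handles
legend = ax.legend(title="y_train")

fig.canvas.draw()  # render once; use plt.show() or fig.savefig(...) in real use
```

The trade-off is one artist per class instead of a single colormapped scatter, which is fine for a few classes but gets unwieldy for many.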
