44from typing import Literal
55
66import numpy as np
7- from numpy import (outer , cos , sin , ones )
8-
7+ from numpy import cos , ones , outer , sin
98from packaging .version import parse as parse_version
109
1110from . import Qobj , expect , sigmax , sigmay , sigmaz
1211
1312try :
1413 import matplotlib
1514 import matplotlib .pyplot as plt
16- from mpl_toolkits .mplot3d import Axes3D
1715 from matplotlib .patches import FancyArrowPatch
18- from mpl_toolkits .mplot3d import proj3d
16+ from mpl_toolkits .mplot3d import Axes3D , proj3d
1917
2018 # Define a custom _axes3D function based on the matplotlib version.
2119 # The auto_add_to_figure keyword is new for matplotlib>=3.4.
@@ -385,8 +383,9 @@ def add_states(self, state: Qobj,
385383 kind : {'vector', 'point'}
386384 Type of object to plot.
387385
388- colors : array_like
386+ colors : str or array_like
389387 Optional array with colors for the states.
388+ The colors can be a string or a RGB or RGBA tuple.
390389
391390 alpha : float, default=1.
392391 Transparency value for the vectors. Values between 0 and 1.
@@ -404,14 +403,18 @@ def add_states(self, state: Qobj,
404403 colors = np .asarray (colors )
405404
406405 if colors .ndim == 0 :
406+ colors = np .repeat (colors , state .shape [0 ])
407+
408+ elif colors .ndim == 1 and np .isdtype (colors .dtype , ("integral" , "real floating" )):
407409 colors = colors [np .newaxis ]
408-
409- if colors .shape != state .shape :
410+ colors = np .repeat (colors , [state .shape [0 ]], axis = 0 )
411+
412+ if colors .shape [0 ] != state .shape [0 ]:
410413 raise ValueError ("The included colors are not valid. "
411- "colors must be equivalent to a 1D array "
412- "with the same size as the number of states." )
414+ "colors must have the same size as state." )
415+
413416 else :
414- colors = np .array ([None ] * state .size )
417+ colors = np .array ([None ] * state .shape [ 0 ] )
415418
416419 for k , st in enumerate (state ):
417420 vec = _state_to_cartesian_coordinates (st )
@@ -420,6 +423,9 @@ def add_states(self, state: Qobj,
420423 self .add_vectors (vec , colors = [colors [k ]], alpha = alpha )
421424 elif kind == 'point' :
422425 self .add_points (vec , colors = [colors [k ]], alpha = alpha )
426+ else :
427+ raise ValueError ("The included kind is not valid. "
428+ f"It should be vector or point, not { kind } ." )
423429
424430 def add_vectors (self , vectors , colors = None , alpha = 1.0 ):
425431 """Add a list of vectors to Bloch sphere.
@@ -429,8 +435,9 @@ def add_vectors(self, vectors, colors=None, alpha=1.0):
429435 vectors : array_like
430436 Array with vectors of unit length or smaller.
431437
432- colors : array_like
438+ colors : str or array_like
433439 Optional array with colors for the vectors.
440+ The colors can be a string or a RGB or RGBA tuple.
434441
435442 alpha : float, default=1.
436443 Transparency value for the vectors. Values between 0 and 1.
@@ -448,16 +455,20 @@ def add_vectors(self, vectors, colors=None, alpha=1.0):
448455 "index represents the iteration over the vectors and the "
449456 "second index represents the position in 3D of vector head." )
450457
451- n_vectors = vectors .shape [0 ]
452458 if colors is None :
453- colors = np .array ([None ] * n_vectors )
459+ colors = np .array ([None ] * vectors . shape [ 0 ] )
454460 else :
455461 colors = np .asarray (colors )
456462
457- if colors .ndim != 1 or colors .size != n_vectors :
458- raise ValueError ("The included colors are not valid. colors must "
459- "be equivalent to a 1D array with the same "
460- "size as the number of vectors. " )
463+ if colors .ndim == 0 :
464+ colors = np .repeat (colors , vectors .shape [0 ])
465+
466+ if (
467+ colors .shape [0 ] != vectors .shape [0 ]
468+ or colors .ndim == 2 and not np .isdtype (colors .dtype , ("integral" , "real floating" ))
469+ ):
470+ raise ValueError ("The included colors are not valid. "
471+ "colors must have the same size as vectors." )
461472
462473 for k , vec in enumerate (vectors ):
463474 self .vectors .append (vec )
0 commit comments