Skip to content

Commit 4c94bee

Browse files
committed
Enable again to plot states with RGBA tuple color on the Bloch sphere
1 parent 87b70ed commit 4c94bee

File tree

3 files changed

+39
-21
lines changed

3 files changed

+39
-21
lines changed

doc/changes/2678.bugfix

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Enable again to plot a state as a point on the Bloch sphere with a RGBA tuple.
2+
Enable to do the same as a vector.

qutip/bloch.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44
from typing import Literal
55

66
import numpy as np
7-
from numpy import (outer, cos, sin, ones)
8-
7+
from numpy import cos, ones, outer, sin
98
from packaging.version import parse as parse_version
109

1110
from . import Qobj, expect, sigmax, sigmay, sigmaz
1211

1312
try:
1413
import matplotlib
1514
import matplotlib.pyplot as plt
16-
from mpl_toolkits.mplot3d import Axes3D
1715
from matplotlib.patches import FancyArrowPatch
18-
from mpl_toolkits.mplot3d import proj3d
16+
from mpl_toolkits.mplot3d import Axes3D, proj3d
1917

2018
# Define a custom _axes3D function based on the matplotlib version.
2119
# The auto_add_to_figure keyword is new for matplotlib>=3.4.
@@ -385,8 +383,9 @@ def add_states(self, state: Qobj,
385383
kind : {'vector', 'point'}
386384
Type of object to plot.
387385
388-
colors : array_like
386+
colors : str or array_like
389387
Optional array with colors for the states.
388+
The colors can be a string or a RGB or RGBA tuple.
390389
391390
alpha : float, default=1.
392391
Transparency value for the vectors. Values between 0 and 1.
@@ -404,14 +403,18 @@ def add_states(self, state: Qobj,
404403
colors = np.asarray(colors)
405404

406405
if colors.ndim == 0:
406+
colors = np.repeat(colors, state.shape[0])
407+
408+
elif colors.ndim == 1 and np.isdtype(colors.dtype, ("integral", "real floating")):
407409
colors = colors[np.newaxis]
408-
409-
if colors.shape != state.shape:
410+
colors = np.repeat(colors, [state.shape[0]], axis=0)
411+
412+
if colors.shape[0] != state.shape[0]:
410413
raise ValueError("The included colors are not valid. "
411-
"colors must be equivalent to a 1D array "
412-
"with the same size as the number of states.")
414+
"colors must have the same size as state.")
415+
413416
else:
414-
colors = np.array([None] * state.size)
417+
colors = np.array([None] * state.shape[0])
415418

416419
for k, st in enumerate(state):
417420
vec = _state_to_cartesian_coordinates(st)
@@ -420,6 +423,9 @@ def add_states(self, state: Qobj,
420423
self.add_vectors(vec, colors=[colors[k]], alpha=alpha)
421424
elif kind == 'point':
422425
self.add_points(vec, colors=[colors[k]], alpha=alpha)
426+
else:
427+
raise ValueError("The included kind is not valid. "
428+
f"It should be vector or point, not {kind}.")
423429

424430
def add_vectors(self, vectors, colors=None, alpha=1.0):
425431
"""Add a list of vectors to Bloch sphere.
@@ -429,8 +435,9 @@ def add_vectors(self, vectors, colors=None, alpha=1.0):
429435
vectors : array_like
430436
Array with vectors of unit length or smaller.
431437
432-
colors : array_like
438+
colors : str or array_like
433439
Optional array with colors for the vectors.
440+
The colors can be a string or a RGB or RGBA tuple.
434441
435442
alpha : float, default=1.
436443
Transparency value for the vectors. Values between 0 and 1.
@@ -448,16 +455,20 @@ def add_vectors(self, vectors, colors=None, alpha=1.0):
448455
"index represents the iteration over the vectors and the "
449456
"second index represents the position in 3D of vector head.")
450457

451-
n_vectors = vectors.shape[0]
452458
if colors is None:
453-
colors = np.array([None] * n_vectors)
459+
colors = np.array([None] * vectors.shape[0])
454460
else:
455461
colors = np.asarray(colors)
456462

457-
if colors.ndim != 1 or colors.size != n_vectors:
458-
raise ValueError("The included colors are not valid. colors must "
459-
"be equivalent to a 1D array with the same "
460-
"size as the number of vectors. ")
463+
if colors.ndim == 0:
464+
colors = np.repeat(colors, vectors.shape[0])
465+
466+
if (
467+
colors.shape[0] != vectors.shape[0]
468+
or colors.ndim == 2 and not np.isdtype(colors.dtype, ("integral", "real floating"))
469+
):
470+
raise ValueError("The included colors are not valid. "
471+
"colors must have the same size as vectors.")
461472

462473
for k, vec in enumerate(vectors):
463474
self.vectors.append(vec)

qutip/tests/test_bloch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -488,14 +488,20 @@ def plot_vector_ref(self, fig, vector_kws):
488488
dict(vectors=[(1, 0, 1), (1, 1, 0)], alpha=0.5),
489489
], id="alpha-multiple-vector-sets"),
490490
pytest.param(
491-
dict(vectors=(0, 0, 1), colors=['y']), id="color-y"),
491+
dict(vectors=(0, 0, 1), colors=['#4f52cc']), id="color-hex"),
492492
pytest.param(
493493
dict(vectors=[(0, 0, 1), (0, 1, 0)], colors=['y', 'y']),
494494
id="color-two-y"),
495495
pytest.param([
496496
dict(vectors=[(0, 0, 1)], colors=['y']),
497497
dict(vectors=[(1, 0, 1)], colors=['g']),
498498
], id="color-yg"),
499+
pytest.param(
500+
dict(vectors=[(0, 0, 1), (0, 1, 0)], colors=[(0.4, 0.7, 0.5), (0.1, 0.2, 0.8)]),
501+
id="color-RGB"),
502+
pytest.param(
503+
dict(vectors=[(0, 0, 1), (0, 1, 0)], colors=[(0.4, 0.7, 0.5, 0.9), (0.1, 0.2, 0.8, 0.4)]),
504+
id="color-RGBA"),
499505
])
500506
@check_pngs_equal
501507
def test_vector(self, vector_kws, fig_test, fig_ref):
@@ -536,9 +542,8 @@ def test_vector_errors_color_length(self, vectors, colors):
536542
b.add_vectors(vectors, colors=colors)
537543
b.render()
538544

539-
err_msg = ("The included colors are not valid. colors must "
540-
"be equivalent to a 1D array with the same "
541-
"size as the number of vectors. ")
545+
err_msg = ("The included colors are not valid. "
546+
"colors must have the same size as vectors.")
542547
assert str(err.value) == err_msg
543548

544549
@check_pngs_equal

0 commit comments

Comments
 (0)