2222from  pennylane .exceptions  import  MeasurementShapeError , QuantumFunctionError 
2323from  pennylane .operation  import  Operator 
2424from  pennylane .queuing  import  QueuingManager 
25- from  pennylane .wires  import  Wires 
25+ from  pennylane .typing  import  TensorLike 
26+ from  pennylane .wires  import  Wires , WiresLike 
2627
2728from  .counts  import  CountsMP 
2829from  .measurements  import  SampleMeasurement 
@@ -46,11 +47,15 @@ class SampleMP(SampleMeasurement):
4647            This can only be specified if an observable was not provided. 
4748        id (str): custom label given to a measurement instance, can be useful for some applications 
4849            where the instance has to be identified 
50+         dtype (str or None): The dtype of the samples returned by this measurement process. 
4951    """ 
5052
5153    _shortname  =  "sample" 
5254
53-     def  __init__ (self , obs = None , wires = None , eigvals = None , id = None ):
55+     # pylint: disable=too-many-arguments 
56+     def  __init__ (self , obs = None , wires = None , eigvals = None , id = None , dtype = None ):
57+ 
58+         self ._dtype  =  dtype 
5459
5560        if  isinstance (obs , MeasurementValue ):
5661            super ().__init__ (obs = obs )
@@ -86,7 +91,7 @@ def _abstract_eval(
8691        has_eigvals = False ,
8792        shots : int  |  None  =  None ,
8893        num_device_wires : int  =  0 ,
89-     ):
94+     )  ->   tuple [ tuple [ int , ...],  type ] :
9095        if  shots  is  None :
9196            raise  ValueError ("finite shots are required to use SampleMP" )
9297        sample_eigvals  =  n_wires  is  None  or  has_eigvals 
@@ -99,6 +104,8 @@ def _abstract_eval(
99104
100105    @property  
101106    def  numeric_type (self ):
107+         if  self ._dtype  is  not   None :
108+             return  self ._dtype 
102109        if  self .obs  is  None :
103110            # Computational basis samples 
104111            return  int 
@@ -123,15 +130,16 @@ def shape(self, shots: int | None = None, num_device_wires: int = 0) -> tuple:
123130    def  process_samples (
124131        self ,
125132        samples : Sequence [complex ],
126-         wire_order : Wires ,
133+         wire_order : WiresLike ,
127134        shot_range : None  |  tuple [int , ...] =  None ,
128135        bin_size : None  |  int  =  None ,
129-     ):
136+     ) ->  TensorLike :
137+ 
130138        return  process_raw_samples (
131-             self , samples , wire_order , shot_range = shot_range , bin_size = bin_size 
139+             self , samples , wire_order , shot_range = shot_range , bin_size = bin_size ,  dtype = self . _dtype 
132140        )
133141
134-     def  process_counts (self , counts : dict , wire_order : Wires ) :
142+     def  process_counts (self , counts : dict , wire_order : WiresLike )  ->   np . ndarray :
135143        samples  =  []
136144        mapped_counts  =  self ._map_counts (counts , wire_order )
137145        for  outcome , count  in  mapped_counts .items ():
@@ -143,11 +151,11 @@ def process_counts(self, counts: dict, wire_order: Wires):
143151
144152        return  np .array (samples )
145153
146-     def  _map_counts (self , counts_to_map , wire_order ) ->  dict :
154+     def  _map_counts (self , counts_to_map :  dict , wire_order :  WiresLike ) ->  dict :
147155        """ 
148156        Args: 
149-             counts_to_map: Dictionary where key is binary representation of the outcome and value is its count 
150-             wire_order: Order of wires to which counts_to_map should be ordered in 
157+             counts_to_map (dict) : Dictionary where key is binary representation of the outcome and value is its count 
158+             wire_order (WiresLike) : Order of wires to which counts_to_map should be ordered in 
151159
152160        Returns: 
153161            Dictionary where counts_to_map has been reordered according to wire_order 
@@ -156,7 +164,7 @@ def _map_counts(self, counts_to_map, wire_order) -> dict:
156164            helper_counts  =  CountsMP (wires = self .wires , all_outcomes = False )
157165        return  helper_counts .process_counts (counts_to_map , wire_order )
158166
159-     def  _compute_outcome_sample (self , outcome ) ->  list :
167+     def  _compute_outcome_sample (self , outcome :  str ) ->  list :
160168        """ 
161169        Args: 
162170            outcome (str): The binary string representation of the measurement outcome. 
@@ -174,11 +182,12 @@ def _compute_outcome_sample(self, outcome) -> list:
174182
175183def  sample (
176184    op : Operator  |  MeasurementValue  |  Sequence [MeasurementValue ] |  None  =  None ,
177-     wires = None ,
185+     wires : WiresLike  =  None ,
186+     dtype = None ,
178187) ->  SampleMP :
179188    r"""Sample from the supplied observable, with the number of shots 
180189    determined from QNode, 
181-     returning raw samples. If no observable is provided then basis state samples are returned 
190+     returning raw samples. If no observable is provided,  then basis state samples are returned 
182191    directly from the device. 
183192
184193    Note that the output shape of this measurement process depends on the shots 
@@ -189,6 +198,7 @@ def sample(
189198            for mid-circuit measurements, ``op`` should be a ``MeasurementValue``. 
190199        wires (Sequence[int] or int or None): the wires we wish to sample from; ONLY set wires if 
191200            op is ``None``. 
201+         dtype: The dtype of the samples returned by this measurement process. 
192202
193203    Returns: 
194204        SampleMP: Measurement process instance 
@@ -292,8 +302,8 @@ def circuit(x):
292302    array([ 1.,  1.,  1., -1.]) 
293303
294304    If no observable is provided, then the raw basis state samples obtained 
295-     from device are returned (e.g., for a qubit device, samples from the 
296-     computational device  are returned). In this case, ``wires`` can be specified 
305+     from the  device are returned (e.g., for a qubit device, samples from the 
306+     computational basis  are returned). In this case, ``wires`` can be specified 
297307    so that sample results only include measurement results of the qubits of interest. 
298308
299309    .. code-block:: python3 
@@ -317,5 +327,54 @@ def circuit(x):
317327           [1, 1], 
318328           [0, 0]]) 
319329
330+     .. details:: 
331+             :title: Setting the precision of the samples 
332+ 
333+             The ``dtype`` argument can be used to set the type and precision of the samples returned by this measurement process 
334+             when the ``op`` argument does not contain mid-circuit measurements. Otherwise, the ``dtype`` argument is ignored. 
335+ 
336+             By default, the samples will be returned as floating point numbers if an observable is provided, 
337+             and as integers if no observable is provided. The ``dtype`` argument can be used to specify further details, 
338+             and set the precision to any valid interface-like dtype, e.g. ``'float32'``, ``'int8'``, ``'uint16'``, etc. 
339+ 
340+             We show two examples below using the JAX and PyTorch interfaces. 
341+             This argument is compatible with all interfaces currently supported by PennyLane. 
342+ 
343+             **Example:** 
344+ 
345+             .. code-block:: python3 
346+ 
347+                 @qml.set_shots(1000000) 
348+                 @qml.qnode(qml.device("default.qubit", wires=1), interface="jax") 
349+                 def circuit(): 
350+                     qml.Hadamard(0) 
351+                     return qml.sample(dtype="int8") 
352+ 
353+             Executing this QNode, we get: 
354+ 
355+             >>> samples = circuit() 
356+             >>> samples.dtype 
357+             dtype('int8') 
358+             >>> type(samples) 
359+             jaxlib._jax.ArrayImpl 
360+ 
361+             If an observable is provided, the samples will be floating point numbers: 
362+ 
363+             .. code-block:: python3 
364+ 
365+                 @qml.set_shots(1000000) 
366+                 @qml.qnode(qml.device("default.qubit", wires=1), interface="torch") 
367+                 def circuit(): 
368+                     qml.Hadamard(0) 
369+                     return qml.sample(qml.Z(0), dtype="float32") 
370+ 
371+             Executing this QNode, we get: 
372+ 
373+             >>> samples = circuit() 
374+             >>> samples.dtype 
375+             torch.float32 
376+             >>> type(samples) 
377+             torch.Tensor 
378+ 
320379    """ 
321-     return  SampleMP (obs = op , wires = None  if  wires  is  None  else  Wires (wires ))
380+     return  SampleMP (obs = op , wires = None  if  wires  is  None  else  Wires (wires ),  dtype = dtype )
0 commit comments