Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e6f2e47
AbsFunctions added
fmwatson Nov 6, 2024
12f8010
First pass at tests and documentation
MargaretDuff Feb 21, 2025
252cec3
Merge branch 'master' into function_of_abs
MargaretDuff Feb 21, 2025
ea99da8
Actually add the file
MargaretDuff Feb 21, 2025
2beefc2
Merge branch 'function_of_abs' of github.com:fmwatson/CIL into pr/fmw…
MargaretDuff Feb 21, 2025
78daaab
Check documentation again
MargaretDuff Feb 21, 2025
55b8cdd
Fix test
MargaretDuff Feb 21, 2025
5c82c66
Improve documentation
MargaretDuff Feb 21, 2025
d9150ea
Merge branch 'master' into function_of_abs
MargaretDuff Feb 21, 2025
620a602
Testing for out in place
MargaretDuff Feb 21, 2025
5c0c893
Merge branch 'function_of_abs' of github.com:fmwatson/CIL into pr/fmw…
MargaretDuff Feb 21, 2025
1eb7d8b
Documentation tweaks
MargaretDuff Feb 26, 2025
d82777d
Update from Edo's review
MargaretDuff Feb 26, 2025
7b4f08a
Proximal map --> proximal operator
MargaretDuff Feb 26, 2025
bd16470
Merge branch 'master' into function_of_abs
MargaretDuff Mar 4, 2025
0572db1
Merge branch 'master' into function_of_abs
MargaretDuff Mar 5, 2025
735a5b6
Updates from Laura's review
MargaretDuff Mar 5, 2025
ac5d4a1
Reference formatting
MargaretDuff Mar 5, 2025
7869881
Merge branch 'master' into function_of_abs
MargaretDuff Mar 5, 2025
c421fb5
Parameters section formatting
MargaretDuff Mar 6, 2025
848ef9e
Added single/double precision options
MargaretDuff Mar 11, 2025
8c5df65
Fix failing unit test
MargaretDuff Mar 11, 2025
1d10cfc
set default precision to double, as necessary for SAR which is curren…
fmwatson Mar 12, 2025
bb44155
added output type hints
fmwatson Mar 12, 2025
4fa481a
convex_conjugate returns 0. when mathematically undefined to ensure c…
fmwatson Mar 12, 2025
d64316f
made decorators separate from class
fmwatson Mar 12, 2025
7b5e577
removed unnecessary comments/old code
fmwatson Mar 12, 2025
8446924
Updates from discussion in dev meeting - still failing out tests
MargaretDuff Mar 13, 2025
1c6a034
Get precision from inputted x
MargaretDuff Mar 20, 2025
c5f5320
Merge branch 'master' into function_of_abs
MargaretDuff Mar 20, 2025
f687556
Merge branch 'master' into function_of_abs
MargaretDuff Apr 29, 2025
9b91a7f
Apply suggestion from @hrobarts
MargaretDuff Jul 24, 2025
b7e17bb
Merge branch 'master' into function_of_abs
lauramurgatroyd Jul 24, 2025
95c2dc1
Update NOTICE.txt
hrobarts Jul 24, 2025
7a622e1
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 24, 2025
00afe5b
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 24, 2025
77cc514
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 25, 2025
e1eeee1
Apply suggestions from code review
hrobarts Jul 25, 2025
25187a4
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 25, 2025
43a9664
Changelog
hrobarts Jul 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 118 additions & 82 deletions Wrappers/Python/cil/optimisation/functions/AbsFunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,47 +25,53 @@


import numpy as np
from cil.optimisation.functions import Function, TotalVariation
from cil.optimisation.functions import Function
from cil.framework import DataContainer
from typing import Optional
import warnings
import logging

log = logging.getLogger(__name__)

class FunctionOfAbs(Function):
r'''A function which acts on the absolute value of the complex-valued input,

.. math:: G(z) = H(abs(z))

This function is initialised with another CIL function, :math:`H` in the above formula. When this function is called, first the absolute value of the input is taken, and then the input is passed to the provided function, :math:.

This function defines the proximal map and convex conjugate, in the case that H is lower semi-continuous, convex, non-decreasing and finite at the origin, but not the gradient which is not implemented. The calculation of the proximal conjugate from the parent CIL Function class is valid for this function.
This function is initialised with another CIL function, :math:`H` in the above formula. When this function is called, first the absolute value of the input is taken, and then the input is passed to the provided function.

Included in this class is the proximal map for FunctionOfAbs. From this, the proximal conjugate is also available from the parent CIL Function class, which is valid for this function.
In the case that :math:`H` is lower semi-continuous, convex, non-decreasing and finite at the origin, (and thus `assume_lower_semi` is set to `True` in the `init`) the convex conjugate is also defined.
The gradient is not defined for this function.



Parameters
----------
function : Function
Function acting on a real-valued input
assume_lower_semi : bool
Function acting on a real input, :math:`H` in the above formula.
assume_lower_semi : bool, default False
If True, assume that the function is lower semi-continuous, convex, non-decreasing and finite at the origin.
This allows the convex conjugate to be calculated as the monotone conjugate, which is less than or equal to the convex conjugate.
If False, the convex conjugate is not implemented. Default is False.

If False, the convex conjugate returned as 0. This is to ensure compatibility with Algorithms such as PDHG.

Reference
---------
For further details see https://doi.org/10.48550/arXiv.2410.22161

'''

def __init__(self, function, assume_lower_semi=False):
def __init__(self, function: Function, assume_lower_semi: bool=False):
self._function = function
self._lower_semi = assume_lower_semi

super().__init__(L=function.L)

def __call__(self, x):
call_abs = self._take_abs_input(self._function.__call__)
def __call__(self, x: DataContainer) -> float:
call_abs = _take_abs_input(self._function.__call__)
return call_abs(self._function, x)

def proximal(self, x, tau, out=None):
def proximal(self, x: DataContainer, tau: float, out: Optional[DataContainer]=None) -> DataContainer:
r'''Returns the proximal map of function :math:`\tau G` evaluated at x

.. math:: \text{prox}_{\tau G}(x) = \underset{z}{\text{argmin}} \frac{1}{2}\|z - x\|^{2} + \tau G(z)
Expand All @@ -78,20 +84,22 @@ def proximal(self, x, tau, out=None):
Parameters
----------
x : DataContainer

The input to the function
tau: scalar

The scalar multiplying the function in the proximal map
out: return DataContainer, if None a new DataContainer is returned, default None.
DataContainer to store the result of the proximal map


Returns
-------
DataContainer, the proximal map of the function at x with scalar :math:`\tau`.

'''
prox_abs = self._abs_and_project(self._function.proximal)
prox_abs = _abs_and_project(self._function.proximal)
return prox_abs(self._function, x, tau=tau, out=out)

def convex_conjugate(self, x):
def convex_conjugate(self, x: DataContainer) -> float:
r'''
Evaluation of the function G* at x, where G* is the convex conjugate of function G,

Expand All @@ -108,83 +116,111 @@ def convex_conjugate(self, x):
real x. In other cases, a general convex conjugate is not available or defined.


For reference see: Convex Analysis, R. Tyrrell Rocakfellar, pp110-111.


Parameters
----------
x : DataContainer
The input to the function

Returns
-------
The value of the convex conjugate of the function at x.

Reference
---------
Convex Analysis, R. Tyrrell Rocakfellar, pp110-111

float:

'''

if self._lower_semi:
conv_abs = self._take_abs_input(self._function.convex_conjugate)
conv_abs = _take_abs_input(self._function.convex_conjugate)
return conv_abs(self._function, x)
else:
raise NotImplementedError(
'Convex conjugate not available for this function. If you are sure your function is lower semi-continuous, convex, non-decreasing and finite at the origin, set `assume_lower_semi=True`')

def _take_abs_input(self, func):
'''decorator for function to act on abs of input of a method'''

def _take_abs_decorator(self, x, *args, **kwargs):
rgeo = x.geometry.copy()
rgeo.dtype = np.float64
r = rgeo.allocate(0)
r.array = np.abs(x.array).astype(np.float64)
# func(self, r, *args, **kwargs) for the abstract class implementation
fval = func(r, *args, **kwargs)
return fval
return _take_abs_decorator

def _abs_and_project(self, func):
'''decorator for function to act on abs of input,
with return being projected to the angle of the input.
Requires function return to have the same shape as input,
such as prox.'''

def _abs_project_decorator(self, x, *args, **kwargs):
rgeo = x.geometry.copy()
rgeo.dtype = np.float64
r = rgeo.allocate(0)
r.array = np.abs(x.array).astype(np.float64)
Phi = np.exp(1j*np.angle(x.array))
out = kwargs.get('out', None)
if out is not None:
del kwargs['out']
# func(self, r, *args, **kwargs) for the abstract class implementation
fvals = func(r, *args, **kwargs)

# Douglas-Rachford splitting to find solution in positive orthant
if np.any(fvals.array < 0):
print('AbsFunctions: projection to +ve orthant triggered')
cts = 0
y = r.copy()
while np.any(fvals.array < 0):
tmp = fvals.array - 0.5*y.array + 0.5*r.array
tmp[tmp < 0] = 0.
y.array += tmp - fvals.array
fvals = func(y, *args, **kwargs)
cts += 1
if cts > 10:
fvals.array[fvals.array < 0] = 0.
break

if out is not None:
out.array = fvals.array.astype(np.complex128)*Phi
return out
else:
out = x.geometry.allocate(0)
out.array = fvals.array.astype(np.complex128)*Phi
return out
return _abs_project_decorator

def gradient(self):
warnings.warn('Convex conjugate is not properly for this function, returning 0 for compatibility with optimisation algorithms')
return 0.0

def gradient(self, x):
'''Gradient of the function at x is not defined for this function.
'''
raise NotImplementedError('Gradient not available for this function')




def _take_abs_input(func):
'''Decorator for function to act on abs of input of a method'''

def _take_abs_decorator(self, x: DataContainer, *args, **kwargs):
real_dtype, _ = _get_real_complex_dtype(x)

rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(0)
r.fill(np.abs(x.as_array()).astype(real_dtype))
fval = func(r, *args, **kwargs)
return fval
return _take_abs_decorator


def _abs_and_project(func):
'''Decorator for function to act on abs of input,
with return being projected to the angle of the input.
Requires function return to have the same shape as input,
such as prox.'''


def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):

real_dtype, complex_dtype = _get_real_complex_dtype(x)


rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(None)
r.fill(np.abs(x.as_array()).astype(real_dtype))
Phi = np.exp((1j*np.angle(x.array)))
out = kwargs.pop('out', None)

fvals = func(r, *args, **kwargs)
fvals_np = fvals.as_array()

# Douglas-Rachford splitting to find solution in positive orthant
if np.any(fvals_np < 0):
log.info('AbsFunctions: projection to +ve orthant triggered')
cts = 0
y = r.copy()
fvals_np = fvals.as_array()
while np.any(fvals_np < 0):
tmp = fvals_np - 0.5*y.as_array() + 0.5*r.as_array()
tmp[tmp < 0] = 0.
y += DataContainter(tmp, y.geometry) - fvals
fvals = func(y, *args, **kwargs)
cts += 1
if cts > 10:
fvals_np = fvals.as_array()
fvals_np[fvals_np < 0] = 0.
break

if out is None:
out_geom = x.geometry.copy()
out = out_geom.allocate(None)
if np.isreal(x.as_array()).all():
out.fill( np.real(fvals_np.astype(complex_dtype)*Phi))
else:
out.fill( fvals_np.astype(complex_dtype)*Phi)
return out
return _abs_project_decorator


def _get_real_complex_dtype(x: DataContainer):
'''An internal function to find the type of x and set the corresponding real and complex data types '''

x_dtype = x.as_array().dtype
if np.issubdtype(x_dtype, np.complexfloating):
complex_dtype = x_dtype
complex_example = np.array([1 + 1j], dtype=x_dtype)
real_dtype = np.real(complex_example).dtype
else:
real_dtype = x_dtype
complex_example = 1j*np.array([1], dtype=x_dtype)
complex_dtype = complex_example.dtype
return real_dtype, complex_dtype

74 changes: 61 additions & 13 deletions Wrappers/Python/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,35 +2146,83 @@ def convex_conjugate(self, x):

class TestFunctionOfAbs(unittest.TestCase):
def setUp(self):
self.data = VectorData(np.array([1+2j, -3+4j, -5-6j]))
self.data_complex64 = VectorData(np.array([1+2j, -3+4j, -5-6j], dtype=np.complex64))
self.data_complex128 = VectorData(np.array([1+2j, -3+4j, -5-6j], dtype=np.complex128))
self.data_real32 = VectorData(np.array([1, 2, 3], dtype=np.float32))
self.mock_function = MockFunction()
self.abs_function = FunctionOfAbs(self.mock_function)


def test_initialization(self):
self.assertIsInstance(self.abs_function._function, MockFunction)
self.assertFalse(self.abs_function._lower_semi)




def test_call(self):
result = self.abs_function(self.data)
expected = np.abs(self.data.array)**2
np.testing.assert_array_almost_equal(result, expected)
result = self.abs_function(self.data_complex64)
expected = np.abs(self.data_complex64.array)**2
np.testing.assert_array_almost_equal(result, expected, decimal=4)

result = self.abs_function(self.data_complex128)
expected = np.abs(self.data_complex128.array)**2
np.testing.assert_array_almost_equal(result, expected, decimal=6)

result = self.abs_function(self.data_real32)
expected = np.abs(self.data_real32.array)**2
np.testing.assert_array_almost_equal(result, expected, decimal=6)

def test_get_real_complex_dtype(self):
from cil.optimisation.functions.AbsFunction import _get_real_complex_dtype
real, compl = _get_real_complex_dtype(self.data_complex128)
self.assertEqual( real, np.float64)
self.assertEqual(compl, np.complex128)

real, compl = _get_real_complex_dtype(self.data_complex64)
self.assertEqual( real, np.float32)
self.assertEqual(compl, np.complex64)

real, compl = _get_real_complex_dtype(self.data_real32)
self.assertEqual( real, np.float32)
self.assertEqual(compl, np.complex64)




def test_proximal(self):
tau = 0.5
result = self.abs_function.proximal(self.data, tau)
expected = (np.abs(self.data.array) / (1 + tau)) * np.exp(1j * np.angle(self.data.array))
np.testing.assert_array_almost_equal(result.array, expected)
result = self.abs_function.proximal(self.data_complex128, tau)
expected = (np.abs(self.data_complex128.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_complex128.array))
np.testing.assert_array_almost_equal(result.array, expected, decimal=6)

result = self.abs_function.proximal(self.data_complex64, tau)
expected = (np.abs(self.data_complex64.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_complex64.array))
np.testing.assert_array_almost_equal(result.array, expected, decimal=4)

result = self.abs_function.proximal(self.data_real32, tau)
expected = (np.abs(self.data_real32.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_real32.array))
np.testing.assert_array_almost_equal(result.array, expected, decimal=4)



def test_convex_conjugate_lower_semi(self):
self.abs_function._lower_semi = True
result = self.abs_function.convex_conjugate(self.data)
expected = 0.5 * np.abs(self.data.array) ** 2
np.testing.assert_array_almost_equal(result, expected)
result = self.abs_function.convex_conjugate(self.data_complex128)
expected = 0.5 * np.abs(self.data_complex128.array) ** 2
np.testing.assert_array_almost_equal(result, expected, decimal=6)

result = self.abs_function.convex_conjugate(self.data_complex64)
expected = 0.5 * np.abs(self.data_complex64.array) ** 2
np.testing.assert_array_almost_equal(result, expected, decimal=4)

result = self.abs_function.convex_conjugate(self.data_real32)
expected = 0.5 * np.abs(self.data_real32.array) ** 2
np.testing.assert_array_almost_equal(result, expected, decimal=4)


def test_convex_conjugate_not_implemented(self):
self.abs_function._lower_semi = False
with self.assertRaises(NotImplementedError):
self.abs_function.convex_conjugate(self.data)

self.assertEqual(self.abs_function.convex_conjugate(self.data_real32), 0.)


8 changes: 6 additions & 2 deletions Wrappers/Python/test/test_out_in_place.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def setUp(self):
ag.set_panel(10)

ig = ag.get_ImageGeometry()

ig_complex = ig.copy()
ig_complex.dtype = np.complex128

scalar = 4

Expand Down Expand Up @@ -128,11 +131,12 @@ def setUp(self):
(BlockFunction(L1Norm(),L2NormSquared()), bg, True, True, False),
(BlockFunction(L2NormSquared(),L2NormSquared()), bg, True, True, True),
(L1Sparsity(WaveletOperator(ig)), ig, True, True, False),
(FunctionOfAbs(TotalVariation(backend='cpu'), assume_lower_semi=True), ig , True, True, False)
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start= False), assume_lower_semi=True), ig , True, True, False),
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start= False), assume_lower_semi=True), ig_complex , True, True, False)
]

np.random.seed(5)
self.data_arrays=[np.random.normal(0,1, (10,10)).astype(np.float32), np.array(range(0,65500, 655), dtype=np.uint16).reshape((10,10)), np.random.uniform(-0.1,1,(10,10)).astype(np.float32)]
self.data_arrays=[np.random.normal(0,1, (10,10)).astype(np.float32), np.array(range(0,65500, 655), dtype=np.uint16).reshape((10,10)), np.random.uniform(-0.1,1,(10,10)).astype(np.float32) ]

def get_result(self, function, method, x, *args):
try:
Expand Down
Loading