
AbsFunctions added #1976


Merged Jul 25, 2025 · 40 commits (changes shown from the first 27 commits)

Commits
e6f2e47
AbsFunctions added
fmwatson Nov 6, 2024
12f8010
First pass at tests and documentation
MargaretDuff Feb 21, 2025
252cec3
Merge branch 'master' into function_of_abs
MargaretDuff Feb 21, 2025
ea99da8
Actually add the file
MargaretDuff Feb 21, 2025
2beefc2
Merge branch 'function_of_abs' of github.com:fmwatson/CIL into pr/fmw…
MargaretDuff Feb 21, 2025
78daaab
Check documentation again
MargaretDuff Feb 21, 2025
55b8cdd
Fix test
MargaretDuff Feb 21, 2025
5c82c66
Improve documentation
MargaretDuff Feb 21, 2025
d9150ea
Merge branch 'master' into function_of_abs
MargaretDuff Feb 21, 2025
620a602
Testing for out in place
MargaretDuff Feb 21, 2025
5c0c893
Merge branch 'function_of_abs' of github.com:fmwatson/CIL into pr/fmw…
MargaretDuff Feb 21, 2025
1eb7d8b
Documentation tweaks
MargaretDuff Feb 26, 2025
d82777d
Update from Edo's review
MargaretDuff Feb 26, 2025
7b4f08a
Proximal map --> proximal operator
MargaretDuff Feb 26, 2025
bd16470
Merge branch 'master' into function_of_abs
MargaretDuff Mar 4, 2025
0572db1
Merge branch 'master' into function_of_abs
MargaretDuff Mar 5, 2025
735a5b6
Updates from Laura's review
MargaretDuff Mar 5, 2025
ac5d4a1
Reference formatting
MargaretDuff Mar 5, 2025
7869881
Merge branch 'master' into function_of_abs
MargaretDuff Mar 5, 2025
c421fb5
Parameters section formatting
MargaretDuff Mar 6, 2025
848ef9e
Added single/double precision options
MargaretDuff Mar 11, 2025
8c5df65
Fix failing unit test
MargaretDuff Mar 11, 2025
1d10cfc
set default precision to double, as necessary for SAR which is curren…
fmwatson Mar 12, 2025
bb44155
added output type hints
fmwatson Mar 12, 2025
4fa481a
convex_conjugate returns 0. when mathematically undefined to ensure c…
fmwatson Mar 12, 2025
d64316f
made decorators separate from class
fmwatson Mar 12, 2025
7b5e577
removed unnecessary comments/old code
fmwatson Mar 12, 2025
8446924
Updates from discussion in dev meeting - still failing out tests
MargaretDuff Mar 13, 2025
1c6a034
Get precision from inputted x
MargaretDuff Mar 20, 2025
c5f5320
Merge branch 'master' into function_of_abs
MargaretDuff Mar 20, 2025
f687556
Merge branch 'master' into function_of_abs
MargaretDuff Apr 29, 2025
9b91a7f
Apply suggestion from @hrobarts
MargaretDuff Jul 24, 2025
b7e17bb
Merge branch 'master' into function_of_abs
lauramurgatroyd Jul 24, 2025
95c2dc1
Update NOTICE.txt
hrobarts Jul 24, 2025
7a622e1
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 24, 2025
00afe5b
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 24, 2025
77cc514
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 25, 2025
e1eeee1
Apply suggestions from code review
hrobarts Jul 25, 2025
25187a4
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 25, 2025
43a9664
Changelog
hrobarts Jul 25, 2025
2 changes: 2 additions & 0 deletions NOTICE.txt
@@ -70,8 +70,10 @@ Sam Porter (2024) - 5
Joshua Hellier (2024) - 3
Nicholas Whyatt (2024) - 1
Rasmia Kulan (2024) - 1
Francis M Watson (2024) - 3
Emmanuel Ferdman (2025) - 14


CIL Advisory Board:
Llion Evans - 9
William Lionheart - 3
227 changes: 227 additions & 0 deletions Wrappers/Python/cil/optimisation/functions/AbsFunction.py
@@ -0,0 +1,227 @@
# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
#
# This work has been supported by the Royal Academy of Engineering and the
# Office of the Chief Science Adviser for National Security under the UK
# Intelligence Community Postdoctoral Research Fellowship programme.
#
# Francis M Watson, University of Manchester 2024
#


import numpy as np
from cil.optimisation.functions import Function
from cil.framework import DataContainer
from typing import Optional
import warnings
import logging

log = logging.getLogger(__name__)

class FunctionOfAbs(Function):
    r'''A function which acts on the absolute value of a complex-valued input,

    .. math:: G(z) = H(abs(z))

    This function is initialised with another CIL function, :math:`H` in the above formula. When this function is called, the absolute value of the input is taken first, and the result is then passed to the provided function.

    This class implements the proximal map of FunctionOfAbs. The proximal conjugate, inherited from the parent CIL Function class, is also valid for this function.
    If :math:`H` is lower semi-continuous, convex, non-decreasing and finite at the origin (in which case `assume_lower_semi` should be set to `True` in the `init`), the convex conjugate is also defined.
    The gradient is not defined for this function.

    Parameters
    ----------
    function : Function
        Function acting on a real input, :math:`H` in the above formula.
    assume_lower_semi : bool, default False
        If True, assume that the function is lower semi-continuous, convex, non-decreasing and finite at the origin.
        This allows the convex conjugate to be calculated as the monotone conjugate, which is less than or equal to the convex conjugate.
        If False, the convex conjugate is returned as 0, to ensure compatibility with algorithms such as PDHG.
    precision : str, default 'double'
        Precision of the calculation, 'single' or 'double'.
        Some complex-valued imaging problems involve high dynamic range and/or require fine phase accuracy, necessitating the use of double precision (the default).
        For example, in synthetic aperture radar imagery a dynamic range of 100 dB or more may be encountered, which is more than single precision allows.
        In other cases, single precision reduces memory use and may improve performance, depending on the compute architecture.

    Reference
    ---------
    For further details see https://doi.org/10.48550/arXiv.2410.22161
    '''

    def __init__(self, function: Function, assume_lower_semi: bool=False, precision: str='double'):
        self._function = function
        self._lower_semi = assume_lower_semi

        if precision == 'single' or precision == 'double':
            self.precision = precision
        else:
            raise ValueError('Precision must be `single` or `double`')

        super().__init__(L=function.L)

    def __call__(self, x: DataContainer) -> float:
        call_abs = (_take_abs_input(self.precision))(self._function.__call__)
        return call_abs(self._function, x)

    def proximal(self, x: DataContainer, tau: float, out: Optional[DataContainer]=None) -> DataContainer:
        r'''Returns the proximal map of function :math:`\tau G` evaluated at x

        .. math:: \text{prox}_{\tau G}(x) = \underset{z}{\text{argmin}} \frac{1}{2}\|z - x\|^{2} + \tau G(z)

        This is accomplished by calculating a bounded proximal map and making a change of phase,
        :math:`\text{prox}_G(z) = \text{prox}^+_H(r) \circ \Phi`, where :math:`z = r \circ \Phi`, :math:`r = \text{abs}(z)`, :math:`\Phi = \exp(i\,\text{angle}(z))`,
        and :math:`\circ` is the element-wise product. Here :math:`\text{prox}^+_H` denotes the proximal map of :math:`H` in which the minimisation is carried out over the positive orthant.

        Parameters
        ----------
        x : DataContainer
            The input to the function
        tau: scalar
            The scalar multiplying the function in the proximal map
        out: DataContainer, default None
            DataContainer to store the result of the proximal map; if None, a new DataContainer is returned.

        Returns
        -------
        DataContainer
            The proximal map of the function evaluated at x with scalar :math:`\tau`.
        '''
        prox_abs = _abs_and_project(self.precision)(self._function.proximal)
        return prox_abs(self._function, x, tau=tau, out=out)

    def convex_conjugate(self, x: DataContainer) -> float:
        r'''
        Evaluation of the function G* at x, where G* is the convex conjugate of function G,

        .. math:: G^{*}(x^{*}) = \underset{x}{\sup} \langle x^{*}, x \rangle - G(x)

        If :math:`H` is lower semi-continuous, convex, non-decreasing and
        finite at the origin, then :math:`G^*(z^*) = H^+(|z^*|)`, where the monotone conjugate :math:`H^+` is

        .. math:: H^+(z^*) = \sup \{ \langle z, z^* \rangle - H(z) : z \geq 0 \}

        The monotone conjugate is therefore less than or equal to the convex conjugate,
        since the supremum is taken over a smaller set. It is not available directly, but may coincide with
        the convex conjugate of :math:`H`, which is therefore used as the best estimate available. This is
        valid only for real :math:`x`; in other cases, a general convex conjugate is not available or defined.

        For reference see: Convex Analysis, R. Tyrrell Rockafellar, pp. 110-111.

        Parameters
        ----------
        x : DataContainer
            The input to the function

        Returns
        -------
        float
            The value of the convex conjugate evaluated at x, or 0.0 if `assume_lower_semi` is False.
        '''
        if self._lower_semi:
            conv_abs = (_take_abs_input(self.precision))(self._function.convex_conjugate)
            return conv_abs(self._function, x)
        else:
            warnings.warn('Convex conjugate is not properly defined for this function; returning 0 for compatibility with optimisation algorithms')
            return 0.0

    def gradient(self, x):
        '''The gradient is not defined for this function.
        '''
        raise NotImplementedError('Gradient not available for this function')

def _take_abs_input(precision='double'):
    def _take_abs_input_inner(func):
        '''Decorator for a method to act on the abs of its input'''
        def _take_abs_decorator(self, x: DataContainer, *args, **kwargs):
            if precision == 'single':
                real_dtype = np.float32
            elif precision == 'double':
                real_dtype = np.float64
            else:
                raise ValueError('Precision must be `single` or `double`')
            rgeo = x.geometry.copy()
            rgeo.dtype = real_dtype
            r = rgeo.allocate(0)
            r.fill(np.abs(x.array).astype(real_dtype))
            fval = func(r, *args, **kwargs)
            return fval
        return _take_abs_decorator
    return _take_abs_input_inner

def _abs_and_project(precision='double'):
    def _abs_and_project_inner(func):
        '''Decorator for a method to act on the abs of its input,
        with the return value projected to the phase (angle) of the input.
        Requires the function's return value to have the same shape as its input,
        as is the case for prox.'''
        def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):
            if precision == 'single':
                real_dtype = np.float32
                complex_dtype = np.complex64
            elif precision == 'double':
                real_dtype = np.float64
                complex_dtype = np.complex128
            else:
                raise ValueError('Precision must be `single` or `double`')
            rgeo = x.geometry.copy()
            rgeo.dtype = real_dtype
            r = rgeo.allocate(None)
            r.fill(np.abs(x.array).astype(real_dtype))
            Phi = np.exp(1j*np.angle(x.array))
            out = kwargs.pop('out', None)

            fvals = func(r, *args, **kwargs)

            # Douglas-Rachford splitting to find a solution in the positive orthant
            if np.any(fvals.array < 0):
                log.info('AbsFunctions: projection to +ve orthant triggered')
                cts = 0
                y = r.copy()
                while np.any(fvals.array < 0):
                    tmp = fvals.array - 0.5*y.array + 0.5*r.array
                    tmp[tmp < 0] = 0.
                    y.array += tmp - fvals.array
                    fvals = func(y, *args, **kwargs)
                    cts += 1
                    if cts > 10:
                        fvals.array[fvals.array < 0] = 0.
                        break

            if out is not None:
                out.array = fvals.array.astype(complex_dtype)*Phi
            else:
                out = x.geometry.allocate(None)
                out.array = fvals.array.astype(complex_dtype)*Phi
            return out
        return _abs_project_decorator
    return _abs_and_project_inner
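
Reviewer note: the positive-orthant loop above is easier to follow on a concrete case. The following standalone numpy sketch mirrors the merged iteration; the choice H(x) = sum(x), whose unconstrained prox r - tau can go negative, is a hypothetical example picked only so that the correction is actually triggered.

import numpy as np

tau = 1.0

def prox_H(v):
    # Unconstrained prox of the hypothetical H(x) = sum(x): a constant shift,
    # which goes negative wherever v < tau
    return v - tau

r = np.array([0.3, 2.0, 0.5])   # stand-in for abs(z)
fvals = prox_H(r)               # [-0.7, 1.0, -0.5]: outside the +ve orthant

# Douglas-Rachford iteration, as in _abs_and_project
y = r.copy()
cts = 0
while np.any(fvals < 0):
    tmp = fvals - 0.5*y + 0.5*r
    tmp[tmp < 0] = 0.
    y += tmp - fvals
    fvals = prox_H(y)
    cts += 1
    if cts > 10:                # safeguard: clip after 10 iterations
        fvals[fvals < 0] = 0.
        break

print(fvals)  # [0., 1., 0.] == np.maximum(r - tau, 0), the prox over z >= 0

For this separable choice of H the loop converges in a single iteration to the proximal map minimised over the positive orthant.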
1 change: 1 addition & 0 deletions Wrappers/Python/cil/optimisation/functions/__init__.py
@@ -39,4 +39,5 @@
from .SGFunction import SGFunction
from .SVRGFunction import SVRGFunction, LSVRGFunction
from .SAGFunction import SAGFunction, SAGAFunction
from .AbsFunction import FunctionOfAbs

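With the export above in place, typical usage looks like the following minimal sketch (not part of the diff). The data mirrors the unit tests below; wrapping L2NormSquared is an illustrative choice, not something this PR prescribes.

import numpy as np
from cil.framework import VectorData
from cil.optimisation.functions import FunctionOfAbs, L2NormSquared

z = VectorData(np.array([1+2j, -3+4j, -5-6j]))

# H = ||.||^2 is lower semi-continuous, convex, finite at 0 and non-decreasing
# on the positive orthant, so the monotone-conjugate route is valid here
G = FunctionOfAbs(L2NormSquared(), assume_lower_semi=True, precision='double')

print(G(z))                  # H(abs(z)) = 5 + 25 + 61 = 91
p = G.proximal(z, tau=0.5)   # shrinks abs(z), then restores the phase of z
print(p.array)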
72 changes: 71 additions & 1 deletion Wrappers/Python/test/test_functions.py
@@ -32,7 +32,7 @@
L1Norm, MixedL21Norm, LeastSquares, \
SmoothMixedL21Norm, OperatorCompositionFunction,\
Rosenbrock, IndicatorBox, TotalVariation, ScaledFunction, SumFunction, SumScalarFunction, \
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity, FunctionOfAbs

from cil.optimisation.functions import BlockFunction

@@ -2124,3 +2124,73 @@ def test_set_num_threads(self):
        N = 10
        ib.set_num_threads(N)
        assert ib.num_threads == N


class MockFunction(Function):
    """Mock function to test FunctionOfAbs"""
    def __init__(self, L=1.0):
        super().__init__(L=L)

    def __call__(self, x):
        return x.array**2  # Simple squared function

    def proximal(self, x, tau, out=None):
        if out is None:
            out = x.geometry.allocate(None)
        out.array = x.array / (1 + tau)  # Simple proximal operator
        return out

    def convex_conjugate(self, x):
        return 0.5 * x.array**2  # Quadratic conjugate function


class TestFunctionOfAbs(unittest.TestCase):
    def setUp(self):
        self.data = VectorData(np.array([1+2j, -3+4j, -5-6j]))
        self.mock_function = MockFunction()
        self.abs_function = FunctionOfAbs(self.mock_function, precision='single')
        self.abs_function_double = FunctionOfAbs(self.mock_function, precision='double')


    def test_initialization(self):
        self.assertIsInstance(self.abs_function._function, MockFunction)
        self.assertFalse(self.abs_function._lower_semi)

        self.assertEqual(self.abs_function.precision, 'single')
        self.assertEqual(self.abs_function_double.precision, 'double')

    def test_call(self):
        result = self.abs_function(self.data)
        expected = np.abs(self.data.array)**2
        np.testing.assert_array_almost_equal(result, expected, decimal=4)

        result = self.abs_function_double(self.data)
        np.testing.assert_array_almost_equal(result, expected, decimal=6)


    def test_proximal(self):
        tau = 0.5
        result = self.abs_function.proximal(self.data, tau)
        expected = (np.abs(self.data.array) / (1 + tau)) * np.exp(1j * np.angle(self.data.array))
        np.testing.assert_array_almost_equal(result.array, expected, decimal=4)

        result = self.abs_function_double.proximal(self.data, tau)
        np.testing.assert_array_almost_equal(result.array, expected)

    def test_convex_conjugate_lower_semi(self):
        self.abs_function._lower_semi = True
        result = self.abs_function.convex_conjugate(self.data)
        expected = 0.5 * np.abs(self.data.array) ** 2
        np.testing.assert_array_almost_equal(result, expected, decimal=4)

        self.abs_function_double._lower_semi = True
        result = self.abs_function_double.convex_conjugate(self.data)
        np.testing.assert_array_almost_equal(result, expected, decimal=6)


    def test_convex_conjugate_lower_semi_false(self):
        # The implementation warns and returns 0.0 rather than raising
        self.abs_function._lower_semi = False
        with self.assertWarns(UserWarning):
            result = self.abs_function.convex_conjugate(self.data)
        self.assertEqual(result, 0.0)


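The proximal test above follows directly from the identity in the proximal docstring. As a standalone check (numpy only, not part of the diff), using MockFunction's prox r/(1+tau), which is non-negative by construction, so the positive-orthant correction in _abs_and_project never fires here:

import numpy as np

z = np.array([1+2j, -3+4j, -5-6j])
tau = 0.5

r = np.abs(z)                    # modulus of the input
Phi = np.exp(1j*np.angle(z))     # unit-modulus phase factor
prox_H = r / (1 + tau)           # MockFunction's prox; non-negative

prox_G = prox_H * Phi            # change of phase: prox_G(z) = prox_H(|z|) o Phi

assert np.allclose(np.abs(prox_G), prox_H)         # modulus shrunk by 1/(1+tau)
assert np.allclose(np.angle(prox_G), np.angle(z))  # phase of z preserved
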
9 changes: 5 additions & 4 deletions Wrappers/Python/test/test_out_in_place.py
@@ -38,7 +38,7 @@
L1Norm, L2NormSquared, MixedL21Norm, LeastSquares, \
SmoothMixedL21Norm, OperatorCompositionFunction, \
IndicatorBox, TotalVariation, SumFunction, SumScalarFunction, \
WeightedL2NormSquared, MixedL11Norm, ZeroFunction
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, FunctionOfAbs

from cil.processors import AbsorptionTransmissionConverter, Binner, CentreOfRotationCorrector, MaskGenerator, Masker, Normaliser, Padder, \
RingRemover, Slicer, TransmissionAbsorptionConverter, PaganinProcessor, FluxNormaliser
@@ -127,9 +127,10 @@ def setUp(self):
(MixedL11Norm(), bg, True, True, False),
(BlockFunction(L1Norm(),L2NormSquared()), bg, True, True, False),
(BlockFunction(L2NormSquared(),L2NormSquared()), bg, True, True, True),
(L1Sparsity(WaveletOperator(ig)), ig, True, True, False)


(L1Sparsity(WaveletOperator(ig)), ig, True, True, False),
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start=False), assume_lower_semi=True), ig, True, True, False),
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start=False), assume_lower_semi=True, precision='double'), ig, True, True, False),

]

np.random.seed(5)
8 changes: 8 additions & 0 deletions docs/source/optimisation.rst
@@ -512,6 +512,14 @@ Total variation
   :members:
   :inherited-members:

Function of Absolute Value
--------------------------

.. autoclass:: cil.optimisation.functions.FunctionOfAbs
   :members:
   :inherited-members:


Approximate Gradient base class
--------------------------------
