Skip to content

AbsFunctions added #1976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e6f2e47
AbsFunctions added
fmwatson Nov 6, 2024
12f8010
First pass at tests and documentation
MargaretDuff Feb 21, 2025
252cec3
Merge branch 'master' into function_of_abs
MargaretDuff Feb 21, 2025
ea99da8
Actually add the file
MargaretDuff Feb 21, 2025
2beefc2
Merge branch 'function_of_abs' of github.com:fmwatson/CIL into pr/fmw…
MargaretDuff Feb 21, 2025
78daaab
Check documentation again
MargaretDuff Feb 21, 2025
55b8cdd
Fix test
MargaretDuff Feb 21, 2025
5c82c66
Improve documentation
MargaretDuff Feb 21, 2025
d9150ea
Merge branch 'master' into function_of_abs
MargaretDuff Feb 21, 2025
620a602
Testing for out in place
MargaretDuff Feb 21, 2025
5c0c893
Merge branch 'function_of_abs' of github.com:fmwatson/CIL into pr/fmw…
MargaretDuff Feb 21, 2025
1eb7d8b
Documentation tweaks
MargaretDuff Feb 26, 2025
d82777d
Update from Edo's review
MargaretDuff Feb 26, 2025
7b4f08a
Proximal map --> proximal operator
MargaretDuff Feb 26, 2025
bd16470
Merge branch 'master' into function_of_abs
MargaretDuff Mar 4, 2025
0572db1
Merge branch 'master' into function_of_abs
MargaretDuff Mar 5, 2025
735a5b6
Updates from Laura's review
MargaretDuff Mar 5, 2025
ac5d4a1
Reference formatting
MargaretDuff Mar 5, 2025
7869881
Merge branch 'master' into function_of_abs
MargaretDuff Mar 5, 2025
c421fb5
Parameters section formatting
MargaretDuff Mar 6, 2025
848ef9e
Added single/double precision options
MargaretDuff Mar 11, 2025
8c5df65
Fix failing unit test
MargaretDuff Mar 11, 2025
1d10cfc
set default precision to double, as necessary for SAR which is curren…
fmwatson Mar 12, 2025
bb44155
added output type hints
fmwatson Mar 12, 2025
4fa481a
convex_conjugate returns 0. when mathematically undefined to ensure c…
fmwatson Mar 12, 2025
d64316f
made decorators separate from class
fmwatson Mar 12, 2025
7b5e577
removed unnecessary comments/old code
fmwatson Mar 12, 2025
8446924
Updates from discussion in dev meeting - still failing out tests
MargaretDuff Mar 13, 2025
1c6a034
Get precision from inputted x
MargaretDuff Mar 20, 2025
c5f5320
Merge branch 'master' into function_of_abs
MargaretDuff Mar 20, 2025
f687556
Merge branch 'master' into function_of_abs
MargaretDuff Apr 29, 2025
9b91a7f
Apply suggestion from @hrobarts
MargaretDuff Jul 24, 2025
b7e17bb
Merge branch 'master' into function_of_abs
lauramurgatroyd Jul 24, 2025
95c2dc1
Update NOTICE.txt
hrobarts Jul 24, 2025
7a622e1
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 24, 2025
00afe5b
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 24, 2025
77cc514
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 25, 2025
e1eeee1
Apply suggestions from code review
hrobarts Jul 25, 2025
25187a4
Update Wrappers/Python/cil/optimisation/functions/AbsFunction.py
hrobarts Jul 25, 2025
43a9664
Changelog
hrobarts Jul 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NOTICE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Emmanuel Ferdman (2025) - 14
Mariam Demir (2025) - 1
Satwik Pani (2025) - 15
Hok Shing Wong (2025) - 7
Francis M Watson (2025) - 3

CIL Advisory Board:
Llion Evans - 9
Expand Down
222 changes: 222 additions & 0 deletions Wrappers/Python/cil/optimisation/functions/AbsFunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
#
# This work has been supported by the Royal Academy of Engineering and the
# Office of the Chief Science Adviser for National Security under the UK
# Intelligence Community Postdoctoral Research Fellowship programme.
#
# Francis M Watson, University of Manchester 2024
#


import numpy as np
from cil.optimisation.functions import Function
from cil.framework import DataContainer
from typing import Optional
import warnings
import logging

log = logging.getLogger(__name__)

class FunctionOfAbs(Function):
r'''A function which acts on the absolute value of the complex-valued input,

.. math:: G(z) = H(abs(z))

This function is initialised with another CIL function, :math:`H` in the above formula. When this function is called, first the absolute value of the input is taken, and then the input is passed to the provided function.

Included in this class is the proximal map for FunctionOfAbs. From this, the proximal conjugate is also available from the parent CIL Function class, which is valid for this function.
In the case that :math:`H` is lower semi-continuous, convex, non-decreasing and finite at the origin, (and thus `assume_lower_semi` is set to `True` in the `init`) the convex conjugate is also defined.
The gradient is not defined for this function.

For reference: see https://doi.org/10.48550/arXiv.2410.22161

Parameters
----------
function : Function
Function acting on a real input, :math:`H` in the above formula.
assume_lower_semi : bool, default False
If True, assume that the function is lower semi-continuous, convex, non-decreasing and finite at the origin.
This allows the convex conjugate to be calculated as the monotone conjugate, which is less than or equal to the convex conjugate.
If False, the convex conjugate returned as 0. This is to ensure compatibility with Algorithms such as PDHG.

'''

def __init__(self, function: Function, assume_lower_semi: bool=False):
self._function = function
self._lower_semi = assume_lower_semi

super().__init__(L=function.L)

def __call__(self, x: DataContainer) -> float:
call_abs = _take_abs_input(self._function.__call__)
return call_abs(self._function, x)

def proximal(self, x: DataContainer, tau: float, out: Optional[DataContainer]=None) -> DataContainer:
r'''Returns the proximal map of function :math:`\tau G` evaluated at x

.. math:: \text{prox}_{\tau G}(x) = \underset{z}{\text{argmin}} \frac{1}{2}\|z - x\|^{2} + \tau G(z)

This is accomplished by calculating a bounded proximal map and making a change of phase,
:math:`prox_G(z) = prox^+_H(r) \circ \Phi` where :math:`z = r \circ \Phi`, :math:`r = abs(z)`, :math:`\Phi = \exp(i \angle(z))`,
and :math:`\circ` is element-wise product. Also define :math:`prox^+` to be the proximal map of :math:`H` in which the minimisation carried out over the positive orthant.


Parameters
----------
x : DataContainer
The input to the function
tau: scalar
The scalar multiplying the function in the proximal map
out: return DataContainer, if None a new DataContainer is returned, default None.
DataContainer to store the result of the proximal map


Returns
-------
DataContainer, the proximal map of the function at x with scalar :math:`\tau`.

'''
prox_abs = _abs_and_project(self._function.proximal)
return prox_abs(self._function, x, tau=tau, out=out)

def convex_conjugate(self, x: DataContainer) -> float:
r'''
Evaluation of the function G* at x, where G* is the convex conjugate of function G,

.. math:: G^{*}(x^{*}) = \underset{x}{\sup} \langle x^{*}, x \rangle - G(x)

If :math:`H` is lower semi-continuous, convex, non-decreasing
finite at the origin, then :math:`G^*(z*) = H^+(|z*|)`, where the monotone conjugate :math:`g^+` is

.. math:: H^+(z^*) =sup {(z, z^*) - H(z) : z >= O}

The monotone conjugate will therefore be less than or equal to the convex conjugate,
since it is taken over a smaller set. It is not available directly, but may coincide with
the convex conjugate, which is therefore the best estimate we have. This is only valid for
real x. In other cases, a general convex conjugate is not available or defined.


For reference see: Convex Analysis, R. Tyrrell Rocakfellar, pp110-111.


Parameters
----------
x : DataContainer
The input to the function

Returns
-------
float:
The value of the convex conjugate of the function at x.
'''

if self._lower_semi:
conv_abs = _take_abs_input(self._function.convex_conjugate)
return conv_abs(self._function, x)
else:
warnings.warn('Convex conjugate is not defined for this function, returning 0 for compatibility with optimisation algorithms')
return 0.0

def gradient(self, x):
'''Gradient of the function at x is not defined for this function.
'''
raise NotImplementedError('Gradient not available for this function')




def _take_abs_input(func):
'''Decorator for function to act on abs of input of a method'''

def _take_abs_decorator(self, x: DataContainer, *args, **kwargs):
real_dtype, _ = _get_real_complex_dtype(x)

rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(0)
r.fill(np.abs(x.as_array()).astype(real_dtype))
fval = func(r, *args, **kwargs)
return fval
return _take_abs_decorator


def _abs_and_project(func):
'''Decorator for function to act on abs of input,
with return being projected to the angle of the input.
Requires function return to have the same shape as input,
such as prox.'''


def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):

real_dtype, complex_dtype = _get_real_complex_dtype(x)


rgeo = x.geometry.copy()
rgeo.dtype = real_dtype
r = rgeo.allocate(None)
r.fill(np.abs(x.as_array()).astype(real_dtype))
Phi = np.exp((1j*np.angle(x.array)))
out = kwargs.pop('out', None)

fvals = func(r, *args, **kwargs)
fvals_np = fvals.as_array()

# Douglas-Rachford splitting to find solution in positive orthant
if np.any(fvals_np < 0):
log.info('AbsFunctions: projection to +ve orthant triggered')
cts = 0
y = r.copy()
fvals_np = fvals.as_array()
while np.any(fvals_np < 0):
tmp = fvals_np - 0.5*y.as_array() + 0.5*r.as_array()
tmp[tmp < 0] = 0.
y += DataContainter(tmp, y.geometry) - fvals
fvals = func(y, *args, **kwargs)
cts += 1
if cts > 10:
fvals_np = fvals.as_array()
fvals_np[fvals_np < 0] = 0.
break

if out is None:
out_geom = x.geometry.copy()
out = out_geom.allocate(None)
if np.isreal(x.as_array()).all():
out.fill( np.real(fvals_np.astype(complex_dtype)*Phi))
else:
out.fill( fvals_np.astype(complex_dtype)*Phi)
return out
return _abs_project_decorator


def _get_real_complex_dtype(x: DataContainer):
'''An internal function to find the type of x and set the corresponding real and complex data types '''

x_dtype = x.as_array().dtype
if np.issubdtype(x_dtype, np.complexfloating):
complex_dtype = x_dtype
complex_example = np.array([1 + 1j], dtype=x_dtype)
real_dtype = np.real(complex_example).dtype
else:
real_dtype = x_dtype
complex_example = 1j*np.array([1], dtype=x_dtype)
complex_dtype = complex_example.dtype
return real_dtype, complex_dtype

1 change: 1 addition & 0 deletions Wrappers/Python/cil/optimisation/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@
from .SGFunction import SGFunction
from .SVRGFunction import SVRGFunction, LSVRGFunction
from .SAGFunction import SAGFunction, SAGAFunction
from .AbsFunction import FunctionOfAbs

104 changes: 103 additions & 1 deletion Wrappers/Python/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
L1Norm, MixedL21Norm, LeastSquares, \
SmoothMixedL21Norm, OperatorCompositionFunction,\
Rosenbrock, IndicatorBox, TotalVariation, ScaledFunction, SumFunction, SumScalarFunction, \
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity, FunctionOfAbs

from cil.optimisation.functions import BlockFunction

Expand Down Expand Up @@ -2124,3 +2124,105 @@ def test_set_num_threads(self):
N = 10
ib.set_num_threads(N)
assert ib.num_threads == N


class MockFunction(Function):
"""Mock function to test FunctionOfAbs"""
def __init__(self, L=1.0):
super().__init__(L=L)

def __call__(self, x):
return x.array**2 # Simple squared function

def proximal(self, x, tau, out=None):
if out is None:
out = x.geometry.allocate(None)
out.array = x.array / (1 + tau) # Simple proximal operator
return out

def convex_conjugate(self, x):
return 0.5 * x.array**2 # Quadratic conjugate function


class TestFunctionOfAbs(unittest.TestCase):
def setUp(self):
self.data_complex64 = VectorData(np.array([1+2j, -3+4j, -5-6j], dtype=np.complex64))
self.data_complex128 = VectorData(np.array([1+2j, -3+4j, -5-6j], dtype=np.complex128))
self.data_real32 = VectorData(np.array([1, 2, 3], dtype=np.float32))
self.mock_function = MockFunction()
self.abs_function = FunctionOfAbs(self.mock_function)


def test_initialization(self):
self.assertIsInstance(self.abs_function._function, MockFunction)
self.assertFalse(self.abs_function._lower_semi)



def test_call(self):
result = self.abs_function(self.data_complex64)
expected = np.abs(self.data_complex64.array)**2
np.testing.assert_array_almost_equal(result, expected, decimal=4)

result = self.abs_function(self.data_complex128)
expected = np.abs(self.data_complex128.array)**2
np.testing.assert_array_almost_equal(result, expected, decimal=6)

result = self.abs_function(self.data_real32)
expected = np.abs(self.data_real32.array)**2
np.testing.assert_array_almost_equal(result, expected, decimal=6)

def test_get_real_complex_dtype(self):
from cil.optimisation.functions.AbsFunction import _get_real_complex_dtype
real, compl = _get_real_complex_dtype(self.data_complex128)
self.assertEqual( real, np.float64)
self.assertEqual(compl, np.complex128)

real, compl = _get_real_complex_dtype(self.data_complex64)
self.assertEqual( real, np.float32)
self.assertEqual(compl, np.complex64)

real, compl = _get_real_complex_dtype(self.data_real32)
self.assertEqual( real, np.float32)
self.assertEqual(compl, np.complex64)




def test_proximal(self):
tau = 0.5
result = self.abs_function.proximal(self.data_complex128, tau)
expected = (np.abs(self.data_complex128.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_complex128.array))
np.testing.assert_array_almost_equal(result.array, expected, decimal=6)

result = self.abs_function.proximal(self.data_complex64, tau)
expected = (np.abs(self.data_complex64.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_complex64.array))
np.testing.assert_array_almost_equal(result.array, expected, decimal=4)

result = self.abs_function.proximal(self.data_real32, tau)
expected = (np.abs(self.data_real32.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_real32.array))
np.testing.assert_array_almost_equal(result.array, expected, decimal=4)



def test_convex_conjugate_lower_semi(self):
self.abs_function._lower_semi = True
result = self.abs_function.convex_conjugate(self.data_complex128)
expected = 0.5 * np.abs(self.data_complex128.array) ** 2
np.testing.assert_array_almost_equal(result, expected, decimal=6)

result = self.abs_function.convex_conjugate(self.data_complex64)
expected = 0.5 * np.abs(self.data_complex64.array) ** 2
np.testing.assert_array_almost_equal(result, expected, decimal=4)

result = self.abs_function.convex_conjugate(self.data_real32)
expected = 0.5 * np.abs(self.data_real32.array) ** 2
np.testing.assert_array_almost_equal(result, expected, decimal=4)


def test_convex_conjugate_not_implemented(self):
self.abs_function._lower_semi = False

self.assertEqual(self.abs_function.convex_conjugate(self.data_real32), 0.)


13 changes: 8 additions & 5 deletions Wrappers/Python/test/test_out_in_place.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
L1Norm, L2NormSquared, MixedL21Norm, LeastSquares, \
SmoothMixedL21Norm, OperatorCompositionFunction, \
IndicatorBox, TotalVariation, SumFunction, SumScalarFunction, \
WeightedL2NormSquared, MixedL11Norm, ZeroFunction
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, FunctionOfAbs

from cil.processors import AbsorptionTransmissionConverter, Binner, CentreOfRotationCorrector, MaskGenerator, Masker, Normaliser, Padder, \
RingRemover, Slicer, TransmissionAbsorptionConverter, PaganinProcessor, FluxNormaliser
Expand Down Expand Up @@ -87,6 +87,9 @@ def setUp(self):
ag.set_panel(10)

ig = ag.get_ImageGeometry()

ig_complex = ig.copy()
ig_complex.dtype = np.complex128

scalar = 4

Expand Down Expand Up @@ -127,13 +130,13 @@ def setUp(self):
(MixedL11Norm(), bg, True, True, False),
(BlockFunction(L1Norm(),L2NormSquared()), bg, True, True, False),
(BlockFunction(L2NormSquared(),L2NormSquared()), bg, True, True, True),
(L1Sparsity(WaveletOperator(ig)), ig, True, True, False)


(L1Sparsity(WaveletOperator(ig)), ig, True, True, False),
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start= False), assume_lower_semi=True), ig , True, True, False),
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start= False), assume_lower_semi=True), ig_complex , True, True, False)
]

np.random.seed(5)
self.data_arrays=[np.random.normal(0,1, (10,10)).astype(np.float32), np.array(range(0,65500, 655), dtype=np.uint16).reshape((10,10)), np.random.uniform(-0.1,1,(10,10)).astype(np.float32)]
self.data_arrays=[np.random.normal(0,1, (10,10)).astype(np.float32), np.array(range(0,65500, 655), dtype=np.uint16).reshape((10,10)), np.random.uniform(-0.1,1,(10,10)).astype(np.float32) ]

def get_result(self, function, method, x, *args):
try:
Expand Down
Loading
Loading