Skip to content

Commit 981604e

Browse files
fmwatsonMargaretDufflauramurgatroydhrobarts
authored
AbsFunctions added (#1976)
* AbsFunctions added * Documentation * set default precision to double, as necessary for SAR which is currently the only use case for FunctionOfAbs returned type hints * added output type hints * convex_conjugate returns 0. when mathematically undefined to ensure compatibility with optimisation algorithms. precision default set to double to ensure sufficient accuracy with e.g. SAR imaging docstring updated to reflect these changes * made decorators separate from class * Get precision from inputted x * Update NOTICE.txt * Changelog --------- Signed-off-by: Margaret Duff <[email protected]> Signed-off-by: Hannah Robarts <[email protected]> Co-authored-by: Margaret Duff <[email protected]> Co-authored-by: Laura Murgatroyd <[email protected]> Co-authored-by: hrobarts <[email protected]>
1 parent b23202d commit 981604e

File tree

7 files changed

+345
-6
lines changed

7 files changed

+345
-6
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
* XX.X
2+
- New features:
3+
- Added `FunctionOfAbs` class (#1976)
24
- Bug fixes:
35
- Fix deprecation warning for rtol and atol in GD (#2056)
46
- Removed the deprecated usage of run method in test_SIRF.py (#2070)

NOTICE.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ Emmanuel Ferdman (2025) - 14
7575
Mariam Demir (2025) - 1
7676
Satwik Pani (2025) - 15
7777
Hok Shing Wong (2025) - 7
78+
Francis M Watson (2025) - 3
7879

7980
CIL Advisory Board:
8081
Llion Evans - 9
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# Copyright 2024 United Kingdom Research and Innovation
2+
# Copyright 2024 The University of Manchester
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
# Authors:
17+
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
18+
#
19+
# This work has been supported by the Royal Academy of Engineering and the
20+
# Office of the Chief Science Adviser for National Security under the UK
21+
# Intelligence Community Postdoctoral Research Fellowship programme.
22+
#
23+
# Francis M Watson, University of Manchester 2024
24+
#
25+
26+
27+
import numpy as np
28+
from cil.optimisation.functions import Function
29+
from cil.framework import DataContainer
30+
from typing import Optional
31+
import warnings
32+
import logging
33+
34+
log = logging.getLogger(__name__)
35+
36+
class FunctionOfAbs(Function):
37+
r'''A function which acts on the absolute value of the complex-valued input,
38+
39+
.. math:: G(z) = H(abs(z))
40+
41+
This function is initialised with another CIL function, :math:`H` in the above formula. When this function is called, first the absolute value of the input is taken, and then the input is passed to the provided function.
42+
43+
Included in this class is the proximal map for FunctionOfAbs. From this, the proximal conjugate is also available from the parent CIL Function class, which is valid for this function.
44+
In the case that :math:`H` is lower semi-continuous, convex, non-decreasing and finite at the origin, (and thus `assume_lower_semi` is set to `True` in the `init`) the convex conjugate is also defined.
45+
The gradient is not defined for this function.
46+
47+
For reference: see https://doi.org/10.48550/arXiv.2410.22161
48+
49+
Parameters
50+
----------
51+
function : Function
52+
Function acting on a real input, :math:`H` in the above formula.
53+
assume_lower_semi : bool, default False
54+
If True, assume that the function is lower semi-continuous, convex, non-decreasing and finite at the origin.
55+
This allows the convex conjugate to be calculated as the monotone conjugate, which is less than or equal to the convex conjugate.
56+
If False, the convex conjugate returned as 0. This is to ensure compatibility with Algorithms such as PDHG.
57+
58+
'''
59+
60+
def __init__(self, function: Function, assume_lower_semi: bool=False):
61+
self._function = function
62+
self._lower_semi = assume_lower_semi
63+
64+
super().__init__(L=function.L)
65+
66+
def __call__(self, x: DataContainer) -> float:
67+
call_abs = _take_abs_input(self._function.__call__)
68+
return call_abs(self._function, x)
69+
70+
def proximal(self, x: DataContainer, tau: float, out: Optional[DataContainer]=None) -> DataContainer:
71+
r'''Returns the proximal map of function :math:`\tau G` evaluated at x
72+
73+
.. math:: \text{prox}_{\tau G}(x) = \underset{z}{\text{argmin}} \frac{1}{2}\|z - x\|^{2} + \tau G(z)
74+
75+
This is accomplished by calculating a bounded proximal map and making a change of phase,
76+
:math:`prox_G(z) = prox^+_H(r) \circ \Phi` where :math:`z = r \circ \Phi`, :math:`r = abs(z)`, :math:`\Phi = \exp(i \angle(z))`,
77+
and :math:`\circ` is element-wise product. Also define :math:`prox^+` to be the proximal map of :math:`H` in which the minimisation carried out over the positive orthant.
78+
79+
80+
Parameters
81+
----------
82+
x : DataContainer
83+
The input to the function
84+
tau: scalar
85+
The scalar multiplying the function in the proximal map
86+
out: return DataContainer, if None a new DataContainer is returned, default None.
87+
DataContainer to store the result of the proximal map
88+
89+
90+
Returns
91+
-------
92+
DataContainer, the proximal map of the function at x with scalar :math:`\tau`.
93+
94+
'''
95+
prox_abs = _abs_and_project(self._function.proximal)
96+
return prox_abs(self._function, x, tau=tau, out=out)
97+
98+
def convex_conjugate(self, x: DataContainer) -> float:
99+
r'''
100+
Evaluation of the function G* at x, where G* is the convex conjugate of function G,
101+
102+
.. math:: G^{*}(x^{*}) = \underset{x}{\sup} \langle x^{*}, x \rangle - G(x)
103+
104+
If :math:`H` is lower semi-continuous, convex, non-decreasing
105+
finite at the origin, then :math:`G^*(z*) = H^+(|z*|)`, where the monotone conjugate :math:`g^+` is
106+
107+
.. math:: H^+(z^*) =sup {(z, z^*) - H(z) : z >= O}
108+
109+
The monotone conjugate will therefore be less than or equal to the convex conjugate,
110+
since it is taken over a smaller set. It is not available directly, but may coincide with
111+
the convex conjugate, which is therefore the best estimate we have. This is only valid for
112+
real x. In other cases, a general convex conjugate is not available or defined.
113+
114+
115+
For reference see: Convex Analysis, R. Tyrrell Rocakfellar, pp110-111.
116+
117+
118+
Parameters
119+
----------
120+
x : DataContainer
121+
The input to the function
122+
123+
Returns
124+
-------
125+
float:
126+
The value of the convex conjugate of the function at x.
127+
'''
128+
129+
if self._lower_semi:
130+
conv_abs = _take_abs_input(self._function.convex_conjugate)
131+
return conv_abs(self._function, x)
132+
else:
133+
warnings.warn('Convex conjugate is not defined for this function, returning 0 for compatibility with optimisation algorithms')
134+
return 0.0
135+
136+
def gradient(self, x):
137+
'''Gradient of the function at x is not defined for this function.
138+
'''
139+
raise NotImplementedError('Gradient not available for this function')
140+
141+
142+
143+
144+
def _take_abs_input(func):
145+
'''Decorator for function to act on abs of input of a method'''
146+
147+
def _take_abs_decorator(self, x: DataContainer, *args, **kwargs):
148+
real_dtype, _ = _get_real_complex_dtype(x)
149+
150+
rgeo = x.geometry.copy()
151+
rgeo.dtype = real_dtype
152+
r = rgeo.allocate(0)
153+
r.fill(np.abs(x.as_array()).astype(real_dtype))
154+
fval = func(r, *args, **kwargs)
155+
return fval
156+
return _take_abs_decorator
157+
158+
159+
def _abs_and_project(func):
160+
'''Decorator for function to act on abs of input,
161+
with return being projected to the angle of the input.
162+
Requires function return to have the same shape as input,
163+
such as prox.'''
164+
165+
166+
def _abs_project_decorator(self, x: DataContainer, *args, **kwargs):
167+
168+
real_dtype, complex_dtype = _get_real_complex_dtype(x)
169+
170+
171+
rgeo = x.geometry.copy()
172+
rgeo.dtype = real_dtype
173+
r = rgeo.allocate(None)
174+
r.fill(np.abs(x.as_array()).astype(real_dtype))
175+
Phi = np.exp((1j*np.angle(x.array)))
176+
out = kwargs.pop('out', None)
177+
178+
fvals = func(r, *args, **kwargs)
179+
fvals_np = fvals.as_array()
180+
181+
# Douglas-Rachford splitting to find solution in positive orthant
182+
if np.any(fvals_np < 0):
183+
log.info('AbsFunctions: projection to +ve orthant triggered')
184+
cts = 0
185+
y = r.copy()
186+
fvals_np = fvals.as_array()
187+
while np.any(fvals_np < 0):
188+
tmp = fvals_np - 0.5*y.as_array() + 0.5*r.as_array()
189+
tmp[tmp < 0] = 0.
190+
y += DataContainter(tmp, y.geometry) - fvals
191+
fvals = func(y, *args, **kwargs)
192+
cts += 1
193+
if cts > 10:
194+
fvals_np = fvals.as_array()
195+
fvals_np[fvals_np < 0] = 0.
196+
break
197+
198+
if out is None:
199+
out_geom = x.geometry.copy()
200+
out = out_geom.allocate(None)
201+
if np.isreal(x.as_array()).all():
202+
out.fill( np.real(fvals_np.astype(complex_dtype)*Phi))
203+
else:
204+
out.fill( fvals_np.astype(complex_dtype)*Phi)
205+
return out
206+
return _abs_project_decorator
207+
208+
209+
def _get_real_complex_dtype(x: DataContainer):
210+
'''An internal function to find the type of x and set the corresponding real and complex data types '''
211+
212+
x_dtype = x.as_array().dtype
213+
if np.issubdtype(x_dtype, np.complexfloating):
214+
complex_dtype = x_dtype
215+
complex_example = np.array([1 + 1j], dtype=x_dtype)
216+
real_dtype = np.real(complex_example).dtype
217+
else:
218+
real_dtype = x_dtype
219+
complex_example = 1j*np.array([1], dtype=x_dtype)
220+
complex_dtype = complex_example.dtype
221+
return real_dtype, complex_dtype
222+

Wrappers/Python/cil/optimisation/functions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@
3939
from .SGFunction import SGFunction
4040
from .SVRGFunction import SVRGFunction, LSVRGFunction
4141
from .SAGFunction import SAGFunction, SAGAFunction
42+
from .AbsFunction import FunctionOfAbs
4243

Wrappers/Python/test/test_functions.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
L1Norm, MixedL21Norm, LeastSquares, \
3333
SmoothMixedL21Norm, OperatorCompositionFunction,\
3434
Rosenbrock, IndicatorBox, TotalVariation, ScaledFunction, SumFunction, SumScalarFunction, \
35-
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity
35+
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, L1Sparsity, FunctionOfAbs
3636

3737
from cil.optimisation.functions import BlockFunction
3838

@@ -2124,3 +2124,105 @@ def test_set_num_threads(self):
21242124
N = 10
21252125
ib.set_num_threads(N)
21262126
assert ib.num_threads == N
2127+
2128+
2129+
class MockFunction(Function):
2130+
"""Mock function to test FunctionOfAbs"""
2131+
def __init__(self, L=1.0):
2132+
super().__init__(L=L)
2133+
2134+
def __call__(self, x):
2135+
return x.array**2 # Simple squared function
2136+
2137+
def proximal(self, x, tau, out=None):
2138+
if out is None:
2139+
out = x.geometry.allocate(None)
2140+
out.array = x.array / (1 + tau) # Simple proximal operator
2141+
return out
2142+
2143+
def convex_conjugate(self, x):
2144+
return 0.5 * x.array**2 # Quadratic conjugate function
2145+
2146+
2147+
class TestFunctionOfAbs(unittest.TestCase):
2148+
def setUp(self):
2149+
self.data_complex64 = VectorData(np.array([1+2j, -3+4j, -5-6j], dtype=np.complex64))
2150+
self.data_complex128 = VectorData(np.array([1+2j, -3+4j, -5-6j], dtype=np.complex128))
2151+
self.data_real32 = VectorData(np.array([1, 2, 3], dtype=np.float32))
2152+
self.mock_function = MockFunction()
2153+
self.abs_function = FunctionOfAbs(self.mock_function)
2154+
2155+
2156+
def test_initialization(self):
2157+
self.assertIsInstance(self.abs_function._function, MockFunction)
2158+
self.assertFalse(self.abs_function._lower_semi)
2159+
2160+
2161+
2162+
def test_call(self):
2163+
result = self.abs_function(self.data_complex64)
2164+
expected = np.abs(self.data_complex64.array)**2
2165+
np.testing.assert_array_almost_equal(result, expected, decimal=4)
2166+
2167+
result = self.abs_function(self.data_complex128)
2168+
expected = np.abs(self.data_complex128.array)**2
2169+
np.testing.assert_array_almost_equal(result, expected, decimal=6)
2170+
2171+
result = self.abs_function(self.data_real32)
2172+
expected = np.abs(self.data_real32.array)**2
2173+
np.testing.assert_array_almost_equal(result, expected, decimal=6)
2174+
2175+
def test_get_real_complex_dtype(self):
2176+
from cil.optimisation.functions.AbsFunction import _get_real_complex_dtype
2177+
real, compl = _get_real_complex_dtype(self.data_complex128)
2178+
self.assertEqual( real, np.float64)
2179+
self.assertEqual(compl, np.complex128)
2180+
2181+
real, compl = _get_real_complex_dtype(self.data_complex64)
2182+
self.assertEqual( real, np.float32)
2183+
self.assertEqual(compl, np.complex64)
2184+
2185+
real, compl = _get_real_complex_dtype(self.data_real32)
2186+
self.assertEqual( real, np.float32)
2187+
self.assertEqual(compl, np.complex64)
2188+
2189+
2190+
2191+
2192+
def test_proximal(self):
2193+
tau = 0.5
2194+
result = self.abs_function.proximal(self.data_complex128, tau)
2195+
expected = (np.abs(self.data_complex128.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_complex128.array))
2196+
np.testing.assert_array_almost_equal(result.array, expected, decimal=6)
2197+
2198+
result = self.abs_function.proximal(self.data_complex64, tau)
2199+
expected = (np.abs(self.data_complex64.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_complex64.array))
2200+
np.testing.assert_array_almost_equal(result.array, expected, decimal=4)
2201+
2202+
result = self.abs_function.proximal(self.data_real32, tau)
2203+
expected = (np.abs(self.data_real32.array) / (1 + tau)) * np.exp(1j * np.angle(self.data_real32.array))
2204+
np.testing.assert_array_almost_equal(result.array, expected, decimal=4)
2205+
2206+
2207+
2208+
def test_convex_conjugate_lower_semi(self):
2209+
self.abs_function._lower_semi = True
2210+
result = self.abs_function.convex_conjugate(self.data_complex128)
2211+
expected = 0.5 * np.abs(self.data_complex128.array) ** 2
2212+
np.testing.assert_array_almost_equal(result, expected, decimal=6)
2213+
2214+
result = self.abs_function.convex_conjugate(self.data_complex64)
2215+
expected = 0.5 * np.abs(self.data_complex64.array) ** 2
2216+
np.testing.assert_array_almost_equal(result, expected, decimal=4)
2217+
2218+
result = self.abs_function.convex_conjugate(self.data_real32)
2219+
expected = 0.5 * np.abs(self.data_real32.array) ** 2
2220+
np.testing.assert_array_almost_equal(result, expected, decimal=4)
2221+
2222+
2223+
def test_convex_conjugate_not_implemented(self):
2224+
self.abs_function._lower_semi = False
2225+
2226+
self.assertEqual(self.abs_function.convex_conjugate(self.data_real32), 0.)
2227+
2228+

Wrappers/Python/test/test_out_in_place.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
L1Norm, L2NormSquared, MixedL21Norm, LeastSquares, \
3939
SmoothMixedL21Norm, OperatorCompositionFunction, \
4040
IndicatorBox, TotalVariation, SumFunction, SumScalarFunction, \
41-
WeightedL2NormSquared, MixedL11Norm, ZeroFunction
41+
WeightedL2NormSquared, MixedL11Norm, ZeroFunction, FunctionOfAbs
4242

4343
from cil.processors import AbsorptionTransmissionConverter, Binner, CentreOfRotationCorrector, MaskGenerator, Masker, Normaliser, Padder, \
4444
RingRemover, Slicer, TransmissionAbsorptionConverter, PaganinProcessor, FluxNormaliser
@@ -87,6 +87,9 @@ def setUp(self):
8787
ag.set_panel(10)
8888

8989
ig = ag.get_ImageGeometry()
90+
91+
ig_complex = ig.copy()
92+
ig_complex.dtype = np.complex128
9093

9194
scalar = 4
9295

@@ -127,13 +130,13 @@ def setUp(self):
127130
(MixedL11Norm(), bg, True, True, False),
128131
(BlockFunction(L1Norm(),L2NormSquared()), bg, True, True, False),
129132
(BlockFunction(L2NormSquared(),L2NormSquared()), bg, True, True, True),
130-
(L1Sparsity(WaveletOperator(ig)), ig, True, True, False)
131-
132-
133+
(L1Sparsity(WaveletOperator(ig)), ig, True, True, False),
134+
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start= False), assume_lower_semi=True), ig , True, True, False),
135+
(FunctionOfAbs(TotalVariation(backend='numpy', warm_start= False), assume_lower_semi=True), ig_complex , True, True, False)
133136
]
134137

135138
np.random.seed(5)
136-
self.data_arrays=[np.random.normal(0,1, (10,10)).astype(np.float32), np.array(range(0,65500, 655), dtype=np.uint16).reshape((10,10)), np.random.uniform(-0.1,1,(10,10)).astype(np.float32)]
139+
self.data_arrays=[np.random.normal(0,1, (10,10)).astype(np.float32), np.array(range(0,65500, 655), dtype=np.uint16).reshape((10,10)), np.random.uniform(-0.1,1,(10,10)).astype(np.float32) ]
137140

138141
def get_result(self, function, method, x, *args):
139142
try:

0 commit comments

Comments
 (0)