STIR 6.4.0
ObjectiveFunctionTests.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2011, Hammersmith Imanet Ltd
3 Copyright (C) 2013, 2020, 2022-2024 University College London
4 This file is part of STIR.
5
6 SPDX-License-Identifier: Apache-2.0
7
8 See STIR/LICENSE.txt for details
9*/
20
#include "stir/RunTests.h"
#include "stir/IO/write_to_file.h"
#include "stir/info.h"
#include "stir/Succeeded.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <string>
26
27START_NAMESPACE_STIR
28
44template <class ObjectiveFunctionT, class TargetT>
46{
47public:
48 typedef ObjectiveFunctionT objective_function_type;
49 typedef TargetT target_type;
50
52
59 virtual Succeeded test_gradient(const std::string& test_name,
60 ObjectiveFunctionT& objective_function,
61 TargetT& target,
62 const float eps,
63 const bool full_gradient = true);
64
66
69 virtual Succeeded
70 test_Hessian(const std::string& test_name, ObjectiveFunctionT& objective_function, const TargetT& target, const float eps);
72 virtual Succeeded test_Hessian_concavity(const std::string& test_name,
73 ObjectiveFunctionT& objective_function,
74 const TargetT& target,
75 const float mult_factor = 1.F);
76
77protected:
79
83 virtual shared_ptr<const TargetT> construct_increment(const TargetT& target, const float eps) const;
84};
85
86template <class ObjectiveFunctionT, class TargetT>
89 ObjectiveFunctionT& objective_function,
90 TargetT& target,
91 const float eps,
92 const bool full_gradient)
93{
94 shared_ptr<TargetT> gradient_sptr(target.get_empty_copy());
95 shared_ptr<TargetT> gradient_2_sptr(target.get_empty_copy());
96 info("Computing gradient");
97 objective_function.compute_gradient(*gradient_sptr, target);
98 this->set_tolerance(std::max(fabs(double(gradient_sptr->find_min())), double(gradient_sptr->find_max())) / 1000);
99 info("Computing objective function at target");
100 const double value_at_target = objective_function.compute_value(target);
101 bool testOK = true;
102 if (full_gradient)
103 {
104 info("Computing gradient of objective function by numerical differences (this will take a while)");
105 auto target_iter = target.begin_all();
106 auto gradient_iter = gradient_sptr->begin_all();
107 auto gradient_2_iter = gradient_2_sptr->begin_all();
108 while (target_iter != target.end_all())
109 {
110 *target_iter += eps;
111 const double value_at_inc = objective_function.compute_value(target);
112 *target_iter -= eps;
113 const float gradient_at_iter = static_cast<float>((value_at_inc - value_at_target) / eps);
114 *gradient_2_iter++ = gradient_at_iter;
115 testOK = testOK && this->check_if_equal(gradient_at_iter, *gradient_iter, "gradient");
116 ++target_iter;
117 ++gradient_iter;
118 }
119 }
120 else
121 {
122 /* test f(x+dx) - f(x) ~ dx^t G(x) */
123 shared_ptr<const TargetT> increment_sptr = this->construct_increment(target, eps);
124 shared_ptr<TargetT> target_plus_inc_sptr(target.clone());
125 *target_plus_inc_sptr += *increment_sptr;
126 const double value_at_inc = objective_function.compute_value(*target_plus_inc_sptr);
127 const double my_sum = std::inner_product(
128 gradient_sptr->begin_all_const(), gradient_sptr->end_all_const(), increment_sptr->begin_all_const(), 0.);
129
130 testOK = testOK && this->check_if_equal(value_at_inc - value_at_target, my_sum, "gradient");
131 }
132
133 if (!testOK)
134 {
135 std::cerr << "Numerical gradient test failed with for " + test_name + "\n";
136 std::cerr << "Writing diagnostic files " << test_name
137 << "_target.hv, *gradient.hv (and *numerical_gradient.hv if full gradient test is used)\n";
138 write_to_file(test_name + "_target.hv", target);
139 write_to_file(test_name + "_gradient.hv", *gradient_sptr);
140 if (full_gradient)
141 write_to_file(test_name + "_numerical_gradient.hv", *gradient_2_sptr);
142 return Succeeded::no;
143 }
144 else
145 {
146 return Succeeded::yes;
147 }
148}
149
150template <class ObjectiveFunctionT, class TargetT>
151shared_ptr<const TargetT>
153{
154 shared_ptr<TargetT> increment_sptr(target.clone());
155 *increment_sptr *= eps / increment_sptr->find_max();
156 *increment_sptr += eps / 2;
157 return increment_sptr;
158}
159
160template <class ObjectiveFunctionT, class TargetT>
163 ObjectiveFunctionT& objective_function,
164 const TargetT& target,
165 const float eps)
166{
167 info("Comparing Hessian*dx with difference of gradients");
168
169 /* test G(x+dx) = G(x) + H dx + small stuff */
170 shared_ptr<TargetT> gradient_sptr(target.get_empty_copy());
171 shared_ptr<TargetT> gradient_2_sptr(target.get_empty_copy());
172 shared_ptr<TargetT> output(target.get_empty_copy());
173 shared_ptr<const TargetT> increment_sptr = this->construct_increment(target, eps);
174 shared_ptr<TargetT> target_plus_inc_sptr(target.clone());
175 *target_plus_inc_sptr += *increment_sptr;
176
177 info("Computing gradient");
178 objective_function.compute_gradient(*gradient_sptr, target);
179 objective_function.compute_gradient(*gradient_2_sptr, *target_plus_inc_sptr);
180 this->set_tolerance(std::max(fabs(double(gradient_sptr->find_min())), double(gradient_sptr->find_max())) / 1E5);
181 info("Computing Hessian * increment at target");
182 objective_function.accumulate_Hessian_times_input(*output, target, *increment_sptr);
183 auto output_iter = output->begin_all_const();
184 auto gradient_iter = gradient_sptr->begin_all_const();
185 auto gradient_2_iter = gradient_2_sptr->begin_all_const();
186 bool testOK = true;
187 while (output_iter != output->end_all())
188 {
189 testOK = testOK && this->check_if_equal(*gradient_2_iter - *gradient_iter, *output_iter, "Hessian*increment");
190 ++output_iter;
191 ++gradient_iter;
192 ++gradient_2_iter;
193 }
194 if (!testOK)
195 {
196 std::cerr << "Numerical Hessian test failed with for " + test_name + "\n";
197 std::cerr << "Writing diagnostic files " << test_name
198 << "_target.hv, *gradient.hv, *increment, *numerical_gradient.hv, *Hessian_times_increment\n";
199 write_to_file(test_name + "_target.hv", target);
200 write_to_file(test_name + "_gradient.hv", *gradient_sptr);
201 write_to_file(test_name + "_increment.hv", *increment_sptr);
202 write_to_file(test_name + "_gradient_at_increment.hv", *gradient_2_sptr);
203 write_to_file(test_name + "_Hessian_times_increment.hv", *output);
204 return Succeeded::no;
205 }
206 else
207 {
208 return Succeeded::yes;
209 }
210}
211
212template <class ObjectiveFunctionT, class TargetT>
215 ObjectiveFunctionT& objective_function,
216 const TargetT& target,
217 const float mult_factor)
218{
220 shared_ptr<TargetT> output(target.get_empty_copy());
221
223 objective_function.accumulate_Hessian_times_input(*output, target, target);
224
226 const float my_sum = std::inner_product(target.begin_all(), target.end_all(), output->begin_all(), 0.F) * mult_factor;
227
228 // test for a CONCAVE function (after multiplying with mult_factor of course)
229 if (this->check_if_less(my_sum, 0))
230 {
231 return Succeeded::yes;
232 }
233 else
234 {
235 // print to console the FAILED configuration
236 std::cerr << "FAIL: " + test_name + ": Computation of x^T H x = " + std::to_string(my_sum)
237 << " > 0 (Hessian) and is therefore NOT concave"
238 << "\n >target image max=" << target.find_max() << "\n >target image min=" << target.find_min()
239 << "\n >output image max=" << output->find_max() << "\n >output image min=" << output->find_min() << '\n';
240 std::cerr << "Writing diagnostic files to " << test_name + "_concavity_out.hv, *target.hv\n";
241 write_to_file(test_name + "_concavity_out.hv", *output);
242 write_to_file(test_name + "_target.hv", target);
243 return Succeeded::no;
244 }
245}
246
247END_NAMESPACE_STIR
defines the stir::RunTests class
Declaration of class stir::Succeeded.
Test class for GeneralisedObjectiveFunction and GeneralisedPrior.
Definition ObjectiveFunctionTests.h:46
virtual shared_ptr< const TargetT > construct_increment(const TargetT &target, const float eps) const
Construct small increment for target.
Definition ObjectiveFunctionTests.h:152
virtual Succeeded test_Hessian(const std::string &test_name, ObjectiveFunctionT &objective_function, const TargetT &target, const float eps)
Test the accumulate_Hessian_times_input of the objective function by comparing to the numerical result via perturbation.
Definition ObjectiveFunctionTests.h:162
virtual Succeeded test_gradient(const std::string &test_name, ObjectiveFunctionT &objective_function, TargetT &target, const float eps, const bool full_gradient=true)
Test the gradient of the objective function by comparing to the numerical gradient via perturbation.
Definition ObjectiveFunctionTests.h:88
virtual Succeeded test_Hessian_concavity(const std::string &test_name, ObjectiveFunctionT &objective_function, const TargetT &target, const float mult_factor=1.F)
Test the Hessian of the objective function by testing the (mult_factor * x^T Hx < 0) condition, i.e. concavity along x.
Definition ObjectiveFunctionTests.h:214
bool check_if_less(T1 a, T2 b, const std::string &str="")
check if a<b
Definition RunTests.h:517
RunTests(const double tolerance=1E-4)
Default constructor.
Definition RunTests.h:305
void set_tolerance(const double tolerance)
Set value used in floating point comparisons (see check_* functions)
Definition RunTests.h:335
a class containing an enumeration type that can be used by functions to signal successful operation or not
Definition Succeeded.h:44
std::string write_to_file(const std::string &filename, const DataT &data)
Function that writes data to file using the default OutputFileFormat.
Definition write_to_file.h:46
void info(const STRING &string, const int verbosity_level=1)
Use this function for writing informational messages.
Definition info.h:51
Declaration of stir::info()
Declaration of stir::write_to_file function (providing easy access to the default stir::OutputFileFormat)