// Test harness for objective functions (declaration fragment: this listing
// omits the class-head line, access specifiers and braces). Per the
// tooltip text further down, this is the ObjectiveFunctionTests class, a
// test class for GeneralisedObjectiveFunction and GeneralisedPrior;
// presumably derived from RunTests, since its member definitions call
// this->set_tolerance() and this->check_if_equal().
44 template <
class ObjectiveFunctionT,
class TargetT>
// Re-export the template parameters for users of the class.
48 typedef ObjectiveFunctionT objective_function_type;
49 typedef TargetT target_type;
// Compares objective_function.compute_gradient() with a finite-difference
// estimate and returns Succeeded::yes/no. The target and eps parameter
// lines are missing from this listing (the definition uses both).
// full_gradient defaults to true (per-element check; otherwise a single
// directional check is done).
59 virtual Succeeded test_gradient(
const std::string& test_name,
60 ObjectiveFunctionT& objective_function,
63 const bool full_gradient =
true);
// Compares accumulate_Hessian_times_input() applied to an increment dx with
// the difference of gradients G(target + dx) - G(target). The return-type
// line is not visible here.
70 test_Hessian(
const std::string& test_name, ObjectiveFunctionT& objective_function,
const TargetT& target,
const float eps);
// Checks the sign of mult_factor * <target, H target>; passes when the
// product is reported "less than 0" by check_if_less. mult_factor defaults
// to 1 (presumably -1 can be passed to test the opposite sign — confirm).
72 virtual Succeeded test_Hessian_concavity(
const std::string& test_name,
73 ObjectiveFunctionT& objective_function,
74 const TargetT& target,
75 const float mult_factor = 1.F);
// Builds the perturbation dx used by the directional tests:
// target rescaled so its maximum is eps, plus a constant eps/2.
83 virtual shared_ptr<const TargetT> construct_increment(
const TargetT& target,
const float eps)
const;
// Definition of test_gradient (fragment: the return type, class qualifier,
// the target/eps parameter lines, braces, the testOK declaration, the
// if/else on full_gradient and several loop statements are missing from
// this listing).
// Steps:
//  1. compute the analytic gradient into gradient_sptr;
//  2. derive the comparison tolerance from its magnitude;
//  3. evaluate the objective at target;
//  4. either compare every element against a forward difference
//     (F(x + eps e_i) - F(x)) / eps, stored in gradient_2_sptr, or compare
//     F(x + dx) - F(x) with <gradient, dx> for dx = construct_increment();
//  5. on failure print diagnostics and return Succeeded::no.
86 template <
class ObjectiveFunctionT,
class TargetT>
89 ObjectiveFunctionT& objective_function,
92 const bool full_gradient)
94 shared_ptr<TargetT> gradient_sptr(target.get_empty_copy());
95 shared_ptr<TargetT> gradient_2_sptr(target.get_empty_copy());
96 info(
"Computing gradient");
97 objective_function.compute_gradient(*gradient_sptr, target);
// Tolerance = max(|min|, max) of the analytic gradient, divided by 1000.
// NOTE(review): an identically-zero gradient gives a zero tolerance;
// confirm check_if_equal treats that as intended.
98 this->set_tolerance(std::max(fabs(
double(gradient_sptr->find_min())), double(gradient_sptr->find_max())) / 1000);
99 info(
"Computing objective function at target");
100 const double value_at_target = objective_function.compute_value(target);
104 info(
"Computing gradient of objective function by numerical differences (this will take a while)");
// Walk target, analytic gradient and numerical gradient in lock-step.
// target is traversed with a mutable iterator (begin_all()), so the
// hidden lines presumably add eps to *target_iter before the
// compute_value() call below and subtract it afterwards.
// NOTE(review): confirm that restore exists, otherwise target is left
// perturbed. Only gradient_2_iter's advance is visible; target_iter and
// gradient_iter are advanced on hidden lines.
105 auto target_iter = target.begin_all();
106 auto gradient_iter = gradient_sptr->begin_all();
107 auto gradient_2_iter = gradient_2_sptr->begin_all();
108 while (target_iter != target.end_all())
111 const double value_at_inc = objective_function.compute_value(target);
// Forward difference quotient: first-order accurate (truncation error
// O(eps)), then narrowed to float.
113 const float gradient_at_iter =
static_cast<float>((value_at_inc - value_at_target) / eps);
114 *gradient_2_iter++ = gradient_at_iter;
115 testOK = testOK && this->check_if_equal(gradient_at_iter, *gradient_iter,
"gradient");
// Directional check (presumably the !full_gradient branch): compare
// F(x + dx) - F(x) with my_sum = <gradient, dx>. my_sum's declaration line
// is hidden; its inner_product starts from the double literal 0., so the
// sum is accumulated in double.
123 shared_ptr<const TargetT> increment_sptr = this->construct_increment(target, eps);
124 shared_ptr<TargetT> target_plus_inc_sptr(target.clone());
125 *target_plus_inc_sptr += *increment_sptr;
126 const double value_at_inc = objective_function.compute_value(*target_plus_inc_sptr);
128 gradient_sptr->begin_all_const(), gradient_sptr->end_all_const(), increment_sptr->begin_all_const(), 0.);
130 testOK = testOK && this->check_if_equal(value_at_inc - value_at_target, my_sum,
"gradient");
// Failure reporting. NOTE(review): the runtime message reads
// "failed with for" (typo); it is left unchanged here because it is
// program output. The _target.hv / _gradient.hv writes announced below are
// on hidden lines; the _numerical_gradient.hv write is presumably guarded
// by full_gradient, since the message ties that file to the full test.
135 std::cerr <<
"Numerical gradient test failed with for " + test_name +
"\n";
136 std::cerr <<
"Writing diagnostic files " << test_name
137 <<
"_target.hv, *gradient.hv (and *numerical_gradient.hv if full gradient test is used)\n";
141 write_to_file(test_name +
"_numerical_gradient.hv", *gradient_2_sptr);
142 return Succeeded::no;
146 return Succeeded::yes;
// Definition of construct_increment (fragment: the class qualifier,
// parameter lines and braces are missing from this listing).
// Returns a clone of target rescaled so that its maximum becomes eps, then
// shifted by eps/2:
//   dx = eps * target / max(target) + eps / 2
// NOTE(review): find_max() == 0 (e.g. an all-zero target) divides by zero
// and yields inf/NaN increments. A negative maximum flips the sign of the
// scaled part. Consider guarding this case, e.g. by falling back to a
// constant eps increment.
150 template <
class ObjectiveFunctionT,
class TargetT>
151 shared_ptr<const TargetT>
154 shared_ptr<TargetT> increment_sptr(target.clone());
155 *increment_sptr *= eps / increment_sptr->find_max();
156 *increment_sptr += eps / 2;
157 return increment_sptr;
// Definition of test_Hessian (fragment: the return type, class qualifier,
// the test_name/eps parameter lines, braces, the testOK declaration,
// iterator increments and some diagnostic writes are missing here).
// Checks element-wise that H(x) dx ~= G(x + dx) - G(x), where
// dx = construct_increment(target, eps) and G is compute_gradient().
160 template <
class ObjectiveFunctionT,
class TargetT>
163 ObjectiveFunctionT& objective_function,
164 const TargetT& target,
167 info(
"Comparing Hessian*dx with difference of gradients");
// output receives accumulate_Hessian_times_input() below.
// NOTE(review): the name suggests the Hessian product is added into
// output, so this relies on get_empty_copy() returning zero-filled data —
// confirm.
170 shared_ptr<TargetT> gradient_sptr(target.get_empty_copy());
171 shared_ptr<TargetT> gradient_2_sptr(target.get_empty_copy());
172 shared_ptr<TargetT> output(target.get_empty_copy());
173 shared_ptr<const TargetT> increment_sptr = this->construct_increment(target, eps);
174 shared_ptr<TargetT> target_plus_inc_sptr(target.clone());
175 *target_plus_inc_sptr += *increment_sptr;
177 info(
"Computing gradient");
178 objective_function.compute_gradient(*gradient_sptr, target);
179 objective_function.compute_gradient(*gradient_2_sptr, *target_plus_inc_sptr);
// Tolerance = max(|min|, max) of the gradient / 1E5, i.e. 100x tighter
// than the /1000 used in test_gradient.
// NOTE(review): confirm the tighter factor is intentional. The tolerance
// is also zero for an all-zero gradient.
180 this->set_tolerance(std::max(fabs(
double(gradient_sptr->find_min())), double(gradient_sptr->find_max())) / 1E5);
181 info(
"Computing Hessian * increment at target");
182 objective_function.accumulate_Hessian_times_input(*output, target, *increment_sptr);
183 auto output_iter = output->begin_all_const();
184 auto gradient_iter = gradient_sptr->begin_all_const();
185 auto gradient_2_iter = gradient_2_sptr->begin_all_const();
// NOTE(review): output_iter comes from begin_all_const() but is compared
// with the non-const output->end_all(); end_all_const() would match the
// iterator type. The three iterators are advanced on hidden lines.
187 while (output_iter != output->end_all())
189 testOK = testOK && this->check_if_equal(*gradient_2_iter - *gradient_iter, *output_iter,
"Hessian*increment");
// Failure reporting. NOTE(review): "failed with for" is a typo in the
// runtime message (left unchanged here). The listed diagnostic names
// ("*numerical_gradient.hv", "*Hessian_times_increment" without .hv) do
// not match the visible writes (_gradient_at_increment.hv,
// _Hessian_times_increment.hv); align the message with the files actually
// written (some writes are on hidden lines).
196 std::cerr <<
"Numerical Hessian test failed with for " + test_name +
"\n";
197 std::cerr <<
"Writing diagnostic files " << test_name
198 <<
"_target.hv, *gradient.hv, *increment, *numerical_gradient.hv, *Hessian_times_increment\n";
202 write_to_file(test_name +
"_gradient_at_increment.hv", *gradient_2_sptr);
203 write_to_file(test_name +
"_Hessian_times_increment.hv", *output);
204 return Succeeded::no;
208 return Succeeded::yes;
// Definition of test_Hessian_concavity (fragment: the return type, class
// qualifier, test_name line, braces and file-writing lines are missing).
// Computes s = mult_factor * <target, H(target) * target> and returns
// Succeeded::yes when check_if_less(s, 0) holds. Otherwise it prints the
// value and the target/output ranges and returns Succeeded::no.
212 template <
class ObjectiveFunctionT,
class TargetT>
215 ObjectiveFunctionT& objective_function,
216 const TargetT& target,
217 const float mult_factor)
220 shared_ptr<TargetT> output(target.get_empty_copy());
223 objective_function.accumulate_Hessian_times_input(*output, target, target);
// NOTE(review): the 0.F initial value makes std::inner_product accumulate
// in float, which loses precision on large targets (test_gradient uses the
// double literal 0.). Consider a double accumulator.
226 const float my_sum =
std::inner_product(target.begin_all(), target.end_all(), output->begin_all(), 0.F) * mult_factor;
// NOTE(review): if check_if_less is a strict comparison, s == 0 (e.g. an
// all-zero target) fails, and the message below then reports "> 0"
// inaccurately — confirm the intended handling.
229 if (this->check_if_less(my_sum, 0))
231 return Succeeded::yes;
// Failure path: report x^T H x and the value ranges of target and output.
// The _concavity_out.hv / target writes announced here are on hidden lines.
236 std::cerr <<
"FAIL: " + test_name +
": Computation of x^T H x = " + std::to_string(my_sum)
237 <<
" > 0 (Hessian) and is therefore NOT concave" 238 <<
"\n >target image max=" << target.find_max() <<
"\n >target image min=" << target.find_min()
239 <<
"\n >output image max=" << output->find_max() <<
"\n >output image min=" << output->find_min() <<
'\n';
240 std::cerr <<
"Writing diagnostic files to " << test_name +
"_concavity_out.hv, *target.hv\n";
243 return Succeeded::no;
Declaration of class stir::Succeeded.
Test class for GeneralisedObjectiveFunction and GeneralisedPrior.
Definition: ObjectiveFunctionTests.h:45
void info(const STRING &string, const int verbosity_level=1)
Use this function for writing informational messages.
Definition: info.h:51
Declaration of the stir::write_to_file function (providing easy access to the default stir::OutputFileFormat).
std::string write_to_file(const std::string &filename, const DataT &data)
Function that writes data to file using the default OutputFileFormat.
Definition: write_to_file.h:47
coordT inner_product(const BasicCoordinate< num_dimensions, coordT > &p1, const BasicCoordinate< num_dimensions, coordT > &p2)
compute sum_i p1[i] * p2[i]
Definition: BasicCoordinate.inl:408
A base class for making test classes. With a derived class, an application could look like...
Definition: RunTests.h:71
Declaration of stir::info()
A class containing an enumeration type that functions can use to signal whether an operation succeeded (Succeeded::yes) or not (Succeeded::no).
Definition: Succeeded.h:43
defines the stir::RunTests class