STIR 6.2.0
ObjectiveFunctionTests.h
/*
    Copyright (C) 2011, Hammersmith Imanet Ltd
    Copyright (C) 2013, 2020, 2022-2024 University College London
    This file is part of STIR.

    SPDX-License-Identifier: Apache-2.0

    See STIR/LICENSE.txt for details
*/
#include "stir/RunTests.h"
#include "stir/IO/write_to_file.h"
#include "stir/info.h"
#include "stir/Succeeded.h"
#include <algorithm> // std::max
#include <cmath>     // fabs
#include <numeric>   // std::inner_product
#include <iostream>

START_NAMESPACE_STIR

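/*!
  \brief Test class for GeneralisedObjectiveFunction and GeneralisedPrior

  A minimal usage sketch. The names \c MyObjectiveFunction and \c make_test_target() are
  hypothetical placeholders for a concrete objective function and a target constructed by the
  derived test class; they are not part of this file.
  \code
  class MyTests : public ObjectiveFunctionTests<MyObjectiveFunction, DiscretisedDensity<3, float>>
  {
  public:
    void run_tests() override
    {
      MyObjectiveFunction objective_function;
      shared_ptr<DiscretisedDensity<3, float>> target_sptr = make_test_target(); // hypothetical helper
      check(test_gradient("my_test", objective_function, *target_sptr, 1e-3F, false) == Succeeded::yes, "gradient");
      check(test_Hessian("my_test", objective_function, *target_sptr, 1e-3F) == Succeeded::yes, "Hessian");
      check(test_Hessian_concavity("my_test", objective_function, *target_sptr) == Succeeded::yes, "concavity");
    }
  };
  \endcode
*/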
template <class ObjectiveFunctionT, class TargetT>
class ObjectiveFunctionTests : public RunTests
{
public:
  typedef ObjectiveFunctionT objective_function_type;
  typedef TargetT target_type;

  //! Compare the analytic gradient of the objective function with a numerical estimate
  /*! If \c full_gradient is \c true, every element of the gradient is checked against a
      forward-difference estimate (slow). Otherwise, only the directional derivative along a
      small increment constructed by construct_increment() is checked. */
  virtual Succeeded test_gradient(const std::string& test_name,
                                  ObjectiveFunctionT& objective_function,
                                  TargetT& target,
                                  const float eps,
                                  const bool full_gradient = true);

  //! Compare Hessian-times-increment with the difference of gradients at two nearby points
  virtual Succeeded
  test_Hessian(const std::string& test_name, ObjectiveFunctionT& objective_function, const TargetT& target, const float eps);
  //! Check the concavity condition x^T H x <= 0 at the target (after scaling by \c mult_factor)
  virtual Succeeded test_Hessian_concavity(const std::string& test_name,
                                           ObjectiveFunctionT& objective_function,
                                           const TargetT& target,
                                           const float mult_factor = 1.F);

protected:
  //! Construct a small increment (perturbation) appropriate for the given target
  virtual shared_ptr<const TargetT> construct_increment(const TargetT& target, const float eps) const;
};

template <class ObjectiveFunctionT, class TargetT>
Succeeded
ObjectiveFunctionTests<ObjectiveFunctionT, TargetT>::test_gradient(const std::string& test_name,
                                                                   ObjectiveFunctionT& objective_function,
                                                                   TargetT& target,
                                                                   const float eps,
                                                                   const bool full_gradient)
{
  shared_ptr<TargetT> gradient_sptr(target.get_empty_copy());
  shared_ptr<TargetT> gradient_2_sptr(target.get_empty_copy());
  info("Computing gradient");
  objective_function.compute_gradient(*gradient_sptr, target);
  this->set_tolerance(std::max(fabs(double(gradient_sptr->find_min())), double(gradient_sptr->find_max())) / 1000);
  info("Computing objective function at target");
  const double value_at_target = objective_function.compute_value(target);
  bool testOK = true;
  if (full_gradient)
    {
      info("Computing gradient of objective function by numerical differences (this will take a while)");
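      /* Forward-difference check: for each element i, the numerical estimate
             g_i ~= (f(x + eps * e_i) - f(x)) / eps
         is compared with the corresponding element of the analytic gradient computed above,
         using the tolerance derived from the gradient's dynamic range. */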
      auto target_iter = target.begin_all();
      auto gradient_iter = gradient_sptr->begin_all();
      auto gradient_2_iter = gradient_2_sptr->begin_all();
      while (target_iter != target.end_all())
        {
          *target_iter += eps;
          const double value_at_inc = objective_function.compute_value(target);
          *target_iter -= eps;
          const float gradient_at_iter = static_cast<float>((value_at_inc - value_at_target) / eps);
          *gradient_2_iter++ = gradient_at_iter;
          testOK = testOK && this->check_if_equal(gradient_at_iter, *gradient_iter, "gradient");
          ++target_iter;
          ++gradient_iter;
        }
    }
  else
    {
      /* test f(x+dx) - f(x) ~ dx^t G(x) */
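      /* This is the first-order Taylor expansion f(x+dx) ~= f(x) + dx^T G(x) for the small
         increment dx returned by construct_increment(): the change in the objective-function
         value is compared with the inner product of the analytic gradient and the increment. */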
      shared_ptr<const TargetT> increment_sptr = this->construct_increment(target, eps);
      shared_ptr<TargetT> target_plus_inc_sptr(target.clone());
      *target_plus_inc_sptr += *increment_sptr;
      const double value_at_inc = objective_function.compute_value(*target_plus_inc_sptr);
      const double my_sum = std::inner_product(
          gradient_sptr->begin_all_const(), gradient_sptr->end_all_const(), increment_sptr->begin_all_const(), 0.);

      testOK = testOK && this->check_if_equal(value_at_inc - value_at_target, my_sum, "gradient");
    }

  if (!testOK)
    {
      std::cerr << "Numerical gradient test failed for " + test_name + "\n";
      std::cerr << "Writing diagnostic files " << test_name
                << "_target.hv, *gradient.hv (and *numerical_gradient.hv if the full gradient test is used)\n";
      write_to_file(test_name + "_target.hv", target);
      write_to_file(test_name + "_gradient.hv", *gradient_sptr);
      if (full_gradient)
        write_to_file(test_name + "_numerical_gradient.hv", *gradient_2_sptr);
      return Succeeded::no;
    }
  else
    {
      return Succeeded::yes;
    }
}

template <class ObjectiveFunctionT, class TargetT>
shared_ptr<const TargetT>
ObjectiveFunctionTests<ObjectiveFunctionT, TargetT>::construct_increment(const TargetT& target, const float eps) const
{
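  /* The increment is a scaled copy of the target plus a constant offset:
         dx = eps * target / max(target) + eps/2,
     so that (for a non-negative target) every element, including zeroes,
     is perturbed by a value between eps/2 and 3*eps/2. */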
  shared_ptr<TargetT> increment_sptr(target.clone());
  *increment_sptr *= eps / increment_sptr->find_max();
  *increment_sptr += eps / 2;
  return increment_sptr;
}

template <class ObjectiveFunctionT, class TargetT>
Succeeded
ObjectiveFunctionTests<ObjectiveFunctionT, TargetT>::test_Hessian(const std::string& test_name,
                                                                  ObjectiveFunctionT& objective_function,
                                                                  const TargetT& target,
                                                                  const float eps)
{
  info("Comparing Hessian*dx with difference of gradients");

  /* test G(x+dx) = G(x) + H dx + small stuff */
  shared_ptr<TargetT> gradient_sptr(target.get_empty_copy());
  shared_ptr<TargetT> gradient_2_sptr(target.get_empty_copy());
  shared_ptr<TargetT> output(target.get_empty_copy());
  shared_ptr<const TargetT> increment_sptr = this->construct_increment(target, eps);
  shared_ptr<TargetT> target_plus_inc_sptr(target.clone());
  *target_plus_inc_sptr += *increment_sptr;

  info("Computing gradient");
  objective_function.compute_gradient(*gradient_sptr, target);
  objective_function.compute_gradient(*gradient_2_sptr, *target_plus_inc_sptr);
  this->set_tolerance(std::max(fabs(double(gradient_sptr->find_min())), double(gradient_sptr->find_max())) / 1E5);
  info("Computing Hessian * increment at target");
  objective_function.accumulate_Hessian_times_input(*output, target, *increment_sptr);
  auto output_iter = output->begin_all_const();
  auto gradient_iter = gradient_sptr->begin_all_const();
  auto gradient_2_iter = gradient_2_sptr->begin_all_const();
  bool testOK = true;
  while (output_iter != output->end_all())
    {
      testOK = testOK && this->check_if_equal(*gradient_2_iter - *gradient_iter, *output_iter, "Hessian*increment");
      ++output_iter;
      ++gradient_iter;
      ++gradient_2_iter;
    }
  if (!testOK)
    {
      std::cerr << "Numerical Hessian test failed for " + test_name + "\n";
      std::cerr << "Writing diagnostic files " << test_name
                << "_target.hv, *gradient.hv, *increment.hv, *gradient_at_increment.hv, *Hessian_times_increment.hv\n";
      write_to_file(test_name + "_target.hv", target);
      write_to_file(test_name + "_gradient.hv", *gradient_sptr);
      write_to_file(test_name + "_increment.hv", *increment_sptr);
      write_to_file(test_name + "_gradient_at_increment.hv", *gradient_2_sptr);
      write_to_file(test_name + "_Hessian_times_increment.hv", *output);
      return Succeeded::no;
    }
  else
    {
      return Succeeded::yes;
    }
}

template <class ObjectiveFunctionT, class TargetT>
Succeeded
ObjectiveFunctionTests<ObjectiveFunctionT, TargetT>::test_Hessian_concavity(const std::string& test_name,
                                                                            ObjectiveFunctionT& objective_function,
                                                                            const TargetT& target,
                                                                            const float mult_factor)
{
  shared_ptr<TargetT> output(target.get_empty_copy());

  objective_function.accumulate_Hessian_times_input(*output, target, target);

  const float my_sum = std::inner_product(target.begin_all(), target.end_all(), output->begin_all(), 0.F) * mult_factor;

  // test for a CONCAVE function (after multiplying with mult_factor of course)
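  /* For a concave objective function the Hessian H is negative semi-definite, so x^T H x <= 0
     for every x. Here that necessary condition is checked only along the single direction
     x = target, after scaling by mult_factor. */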
  if (this->check_if_less(my_sum, 0))
    {
      return Succeeded::yes;
    }
  else
    {
      // print the FAILED configuration to the console
      std::cerr << "FAIL: " + test_name + ": x^T H x = " + std::to_string(my_sum)
                << " > 0, so the (scaled) Hessian is NOT negative semi-definite and the function is NOT concave"
                << "\n >target image max=" << target.find_max() << "\n >target image min=" << target.find_min()
                << "\n >output image max=" << output->find_max() << "\n >output image min=" << output->find_min() << '\n';
      std::cerr << "Writing diagnostic files to " << test_name + "_concavity_out.hv, *target.hv\n";
      write_to_file(test_name + "_concavity_out.hv", *output);
      write_to_file(test_name + "_target.hv", target);
      return Succeeded::no;
    }
}

END_NAMESPACE_STIR