STIR  6.2.0
cuda_utilities.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2024, University College London
3  This file is part of STIR.
4 
5  SPDX-License-Identifier: Apache-2.0
6 
7  See STIR/LICENSE.txt for details
8 */
9 
10 #ifndef __stir_cuda_utilities_H__
11 #define __stir_cuda_utilities_H__
12 
#include "stir/Array.h"
#include "stir/info.h"
#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>
23 
24 START_NAMESPACE_STIR
25 
26 template <int num_dimensions, typename elemT>
27 inline void
28 array_to_device(elemT* dev_data, const Array<num_dimensions, elemT>& stir_array)
29 {
30  if (stir_array.is_contiguous())
31  {
32  info("array_to_device contiguous", 100);
33  cudaMemcpy(dev_data, stir_array.get_const_full_data_ptr(), stir_array.size_all() * sizeof(elemT), cudaMemcpyHostToDevice);
34  stir_array.release_const_full_data_ptr();
35  }
36  else
37  {
38  info("array_to_device non-contiguous", 100);
39  // Allocate host memory to get contiguous vector, copy array to it and copy from device to host
40  std::vector<elemT> tmp_data(stir_array.size_all());
41  std::copy(stir_array.begin_all(), stir_array.end_all(), tmp_data.begin());
42  cudaMemcpy(dev_data, tmp_data.data(), stir_array.size_all() * sizeof(elemT), cudaMemcpyHostToDevice);
43  }
44 }
45 
46 template <int num_dimensions, typename elemT>
47 inline void
48 array_to_host(Array<num_dimensions, elemT>& stir_array, const elemT* dev_data)
49 {
50  if (stir_array.is_contiguous())
51  {
52  info("array_to_host contiguous", 100);
53  cudaMemcpy(stir_array.get_full_data_ptr(), dev_data, stir_array.size_all() * sizeof(elemT), cudaMemcpyDeviceToHost);
54  stir_array.release_full_data_ptr();
55  }
56  else
57  {
58  info("array_to_host non-contiguous", 100);
59  // Allocate host memory for the result and copy from device to host
60  std::vector<elemT> tmp_data(stir_array.size_all());
61  cudaMemcpy(tmp_data.data(), dev_data, stir_array.size_all() * sizeof(elemT), cudaMemcpyDeviceToHost);
62  // Copy the data to the stir_array
63  std::copy(tmp_data.begin(), tmp_data.end(), stir_array.begin_all());
64  }
65 }
66 
67 END_NAMESPACE_STIR
68 
69 #endif
void info(const STRING &string, const int verbosity_level=1)
Use this function for writing informational messages.
Definition: info.h:51
defines the Array class for multi-dimensional (numeric) arrays
Declaration of stir::info()