Sacado Package Browser (Single Doxygen Collection)  Version of the Day
Sacado_Fad_ScalarTraitsImp.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ***********************************************************************
3 //
4 // Sacado Package
5 // Copyright (2006) Sandia Corporation
6 //
7 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8 // the U.S. Government retains certain rights in this software.
9 //
10 // This library is free software; you can redistribute it and/or modify
11 // it under the terms of the GNU Lesser General Public License as
12 // published by the Free Software Foundation; either version 2.1 of the
13 // License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 //
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23 // USA
24 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
25 // (etphipp@sandia.gov).
26 //
27 // ***********************************************************************
28 // @HEADER
29 
30 #ifndef SACADO_FAD_SCALARTRAITSIMP_HPP
31 #define SACADO_FAD_SCALARTRAITSIMP_HPP
32 
33 #ifdef HAVE_SACADO_TEUCHOS
34 
35 #include "Teuchos_ScalarTraits.hpp"
36 #include "Teuchos_SerializationTraits.hpp"
37 #include "Teuchos_SerializationTraitsHelpers.hpp"
38 #include "Teuchos_Assert.hpp"
39 #include "Teuchos_RCP.hpp"
40 #include "Teuchos_Array.hpp"
41 #include "Sacado_mpl_apply.hpp"
42 
43 #include <iterator>
44 
45 namespace Sacado {
46 
47  namespace Fad {
48 
50  template <typename FadType>
51  struct ScalarTraitsImp {
52  typedef typename Sacado::ValueType<FadType>::type ValueT;
53 
54  typedef typename mpl::apply<FadType,typename Teuchos::ScalarTraits<ValueT>::magnitudeType>::type magnitudeType;
55  typedef typename mpl::apply<FadType,typename Teuchos::ScalarTraits<ValueT>::halfPrecision>::type halfPrecision;
56  typedef typename mpl::apply<FadType,typename Teuchos::ScalarTraits<ValueT>::doublePrecision>::type doublePrecision;
57 
58  static const bool isComplex = Teuchos::ScalarTraits<ValueT>::isComplex;
59  static const bool isOrdinal = Teuchos::ScalarTraits<ValueT>::isOrdinal;
60  static const bool isComparable =
61  Teuchos::ScalarTraits<ValueT>::isComparable;
62  static const bool hasMachineParameters =
63  Teuchos::ScalarTraits<ValueT>::hasMachineParameters;
64  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType eps() {
65  return Teuchos::ScalarTraits<ValueT>::eps();
66  }
67  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType sfmin() {
68  return Teuchos::ScalarTraits<ValueT>::sfmin();
69  }
70  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType base() {
71  return Teuchos::ScalarTraits<ValueT>::base();
72  }
73  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType prec() {
74  return Teuchos::ScalarTraits<ValueT>::prec();
75  }
76  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType t() {
77  return Teuchos::ScalarTraits<ValueT>::t();
78  }
79  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType rnd() {
81  }
82  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType emin() {
83  return Teuchos::ScalarTraits<ValueT>::emin();
84  }
85  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType rmin() {
86  return Teuchos::ScalarTraits<ValueT>::rmin();
87  }
88  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType emax() {
89  return Teuchos::ScalarTraits<ValueT>::emax();
90  }
91  static typename Teuchos::ScalarTraits<ValueT>::magnitudeType rmax() {
92  return Teuchos::ScalarTraits<ValueT>::rmax();
93  }
94  static magnitudeType magnitude(const FadType& a) {
95 #ifdef TEUCHOS_DEBUG
96  TEUCHOS_SCALAR_TRAITS_NAN_INF_ERR(
97  a, "Error, the input value to magnitude(...) a = " << a <<
98  " can not be NaN!" );
99  TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(a) == false, std::runtime_error,
100  "Complex magnitude is not a differentiable "
101  "function of complex inputs.");
102 #endif
103  //return std::fabs(a);
104  magnitudeType b(a.size(),
105  Teuchos::ScalarTraits<ValueT>::magnitude(a.val()));
106  if (Teuchos::ScalarTraits<ValueT>::real(a.val()) >= 0)
107  for (int i=0; i<a.size(); i++)
108  b.fastAccessDx(i) =
109  Teuchos::ScalarTraits<ValueT>::magnitude(a.fastAccessDx(i));
110  else
111  for (int i=0; i<a.size(); i++)
112  b.fastAccessDx(i) =
113  -Teuchos::ScalarTraits<ValueT>::magnitude(a.fastAccessDx(i));
114  return b;
115  }
116  static ValueT zero() {
117  return ValueT(0.0);
118  }
119  static ValueT one() {
120  return ValueT(1.0);
121  }
122 
123  // Conjugate is only defined for real derivative components
124  static FadType conjugate(const FadType& x) {
125 #ifdef TEUCHOS_DEBUG
126  TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(x) == false, std::runtime_error,
127  "Complex conjugate is not a differentiable "
128  "function of complex inputs.");
129 #endif
130  FadType y = x;
131  y.val() = Teuchos::ScalarTraits<ValueT>::conjugate(x.val());
132  return y;
133  }
134 
135  // Real part is only defined for real derivative components
136  static FadType real(const FadType& x) {
137 #ifdef TEUCHOS_DEBUG
138  TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(x) == false, std::runtime_error,
139  "Real component is not a differentiable "
140  "function of complex inputs.");
141 #endif
142  FadType y = x;
143  y.val() = Teuchos::ScalarTraits<ValueT>::real(x.val());
144  return y;
145  }
146 
147  // Imaginary part is only defined for real derivative components
148  static FadType imag(const FadType& x) {
149 #ifdef TEUCHOS_DEBUG
150  TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(x) == false, std::runtime_error,
151  "Imaginary component is not a differentiable "
152  "function of complex inputs.");
153 #endif
154  return FadType(Teuchos::ScalarTraits<ValueT>::imag(x.val()));
155  }
156 
157  static ValueT nan() {
158  return Teuchos::ScalarTraits<ValueT>::nan();
159  }
160  static bool isnaninf(const FadType& x) {
161  if (Teuchos::ScalarTraits<ValueT>::isnaninf(x.val()))
162  return true;
163  for (int i=0; i<x.size(); i++)
164  if (Teuchos::ScalarTraits<ValueT>::isnaninf(x.dx(i)))
165  return true;
166  return false;
167  }
168  static void seedrandom(unsigned int s) {
169  Teuchos::ScalarTraits<ValueT>::seedrandom(s);
170  }
171  static ValueT random() {
172  return Teuchos::ScalarTraits<ValueT>::random();
173  }
174  static std::string name() {
176  }
177  static FadType squareroot(const FadType& x) {
178 #ifdef TEUCHOS_DEBUG
179  TEUCHOS_SCALAR_TRAITS_NAN_INF_ERR(
180  x, "Error, the input value to squareroot(...) a = " << x <<
181  " can not be NaN!" );
182 #endif
183  return std::sqrt(x);
184  }
185  static FadType pow(const FadType& x, const FadType& y) {
186  return std::pow(x,y);
187  }
188 
189  // Helper function to determine whether a complex value is real
190  static bool is_complex_real(const ValueT& x) {
191  return
192  Teuchos::ScalarTraits<ValueT>::magnitude(x-Teuchos::ScalarTraits<ValueT>::real(x)) == 0;
193  }
194 
195  // Helper function to determine whether a Fad type is real
196  static bool is_fad_real(const FadType& x) {
197  if (x.size() == 0)
198  return true;
199  if (Teuchos::ScalarTraits<ValueT>::isComplex) {
200  if (!is_complex_real(x.val()))
201  return false;
202  for (int i=0; i<x.size(); i++)
203  if (!is_complex_real(x.fastAccessDx(i)))
204  return false;
205  }
206  return true;
207  }
208 
209  }; // class ScalarTraitsImp
210 
212  template <typename Ordinal, typename FadType, typename Serializer>
213  struct SerializationImp {
214 
215  private:
216 
218  typedef Teuchos::SerializationTraits<Ordinal,int> iSerT;
219 
221  typedef Teuchos::SerializationTraits<Ordinal,Ordinal> oSerT;
222 
224  typedef typename Sacado::ValueType<FadType>::type value_type;
225 
226  public:
227 
229  static const bool supportsDirectSerialization = false;
230 
232 
233 
235  static Ordinal fromCountToIndirectBytes(const Serializer& vs,
236  const Ordinal count,
237  const FadType buffer[],
238  const Ordinal sz = 0) {
239  Ordinal bytes = 0;
240  FadType *x = NULL;
241  const FadType *cx;
242  for (Ordinal i=0; i<count; i++) {
243  int my_sz = buffer[i].size();
244  int tot_sz = sz;
245  if (sz == 0) tot_sz = my_sz;
246  Ordinal b1 = iSerT::fromCountToIndirectBytes(1, &tot_sz);
247  Ordinal b2 = vs.fromCountToIndirectBytes(1, &(buffer[i].val()));
248  Ordinal b3 = oSerT::fromCountToIndirectBytes(1, &b2);
249  Ordinal b4;
250  if (tot_sz != my_sz) {
251  if (x == NULL)
252  x = new FadType(tot_sz, 0.0);
253  *x = buffer[i];
254  x->expand(tot_sz);
255  cx = x;
256  }
257  else
258  cx = &(buffer[i]);
259  b4 = vs.fromCountToIndirectBytes(tot_sz, cx->dx());
260  Ordinal b5 = oSerT::fromCountToIndirectBytes(1, &b4);
261  bytes += b1+b2+b3+b4+b5;
262  }
263  if (x != NULL)
264  delete x;
265  return bytes;
266  }
267 
269  static void serialize (const Serializer& vs,
270  const Ordinal count,
271  const FadType buffer[],
272  const Ordinal bytes,
273  char charBuffer[],
274  const Ordinal sz = 0) {
275  FadType *x = NULL;
276  const FadType *cx;
277  for (Ordinal i=0; i<count; i++) {
278  // First serialize size
279  int my_sz = buffer[i].size();
280  int tot_sz = sz;
281  if (sz == 0) tot_sz = my_sz;
282  Ordinal b1 = iSerT::fromCountToIndirectBytes(1, &tot_sz);
283  iSerT::serialize(1, &tot_sz, b1, charBuffer);
284  charBuffer += b1;
285 
286  // Next serialize value
287  Ordinal b2 = vs.fromCountToIndirectBytes(1, &(buffer[i].val()));
288  Ordinal b3 = oSerT::fromCountToIndirectBytes(1, &b2);
289  oSerT::serialize(1, &b2, b3, charBuffer);
290  charBuffer += b3;
291  vs.serialize(1, &(buffer[i].val()), b2, charBuffer);
292  charBuffer += b2;
293 
294  // Next serialize derivative array
295  Ordinal b4;
296  if (tot_sz != my_sz) {
297  if (x == NULL)
298  x = new FadType(tot_sz, 0.0);
299  *x = buffer[i];
300  x->expand(tot_sz);
301  cx = x;
302  }
303  else
304  cx = &(buffer[i]);
305  b4 = vs.fromCountToIndirectBytes(tot_sz, cx->dx());
306  Ordinal b5 = oSerT::fromCountToIndirectBytes(1, &b4);
307  oSerT::serialize(1, &b4, b5, charBuffer);
308  charBuffer += b5;
309  vs.serialize(tot_sz, cx->dx(), b4, charBuffer);
310  charBuffer += b4;
311  }
312  if (x != NULL)
313  delete x;
314  }
315 
317  static Ordinal fromIndirectBytesToCount(const Serializer& vs,
318  const Ordinal bytes,
319  const char charBuffer[],
320  const Ordinal sz = 0) {
321  Ordinal count = 0;
322  Ordinal bytes_used = 0;
323  while (bytes_used < bytes) {
324 
325  // Bytes for size
326  Ordinal b1 = iSerT::fromCountToDirectBytes(1);
327  bytes_used += b1;
328  charBuffer += b1;
329 
330  // Bytes for value
331  Ordinal b3 = oSerT::fromCountToDirectBytes(1);
332  const Ordinal *b2 = oSerT::convertFromCharPtr(charBuffer);
333  bytes_used += b3;
334  charBuffer += b3;
335  bytes_used += *b2;
336  charBuffer += *b2;
337 
338  // Bytes for derivative array
339  Ordinal b5 = oSerT::fromCountToDirectBytes(1);
340  const Ordinal *b4 = oSerT::convertFromCharPtr(charBuffer);
341  bytes_used += b5;
342  charBuffer += b5;
343  bytes_used += *b4;
344  charBuffer += *b4;
345 
346  ++count;
347  }
348  return count;
349  }
350 
352  static void deserialize (const Serializer& vs,
353  const Ordinal bytes,
354  const char charBuffer[],
355  const Ordinal count,
356  FadType buffer[],
357  const Ordinal sz = 0) {
358  for (Ordinal i=0; i<count; i++) {
359 
360  // Deserialize size
361  Ordinal b1 = iSerT::fromCountToDirectBytes(1);
362  const int *my_sz = iSerT::convertFromCharPtr(charBuffer);
363  charBuffer += b1;
364 
365  // Create empty Fad object of given size
366  int tot_sz = sz;
367  if (sz == 0) tot_sz = *my_sz;
368  buffer[i] = FadType(tot_sz, 0.0);
369 
370  // Deserialize value
371  Ordinal b3 = oSerT::fromCountToDirectBytes(1);
372  const Ordinal *b2 = oSerT::convertFromCharPtr(charBuffer);
373  charBuffer += b3;
374  vs.deserialize(*b2, charBuffer, 1, &(buffer[i].val()));
375  charBuffer += *b2;
376 
377  // Deserialize derivative array
378  Ordinal b5 = oSerT::fromCountToDirectBytes(1);
379  const Ordinal *b4 = oSerT::convertFromCharPtr(charBuffer);
380  charBuffer += b5;
381  vs.deserialize(*b4, charBuffer, *my_sz,
382  &(buffer[i].fastAccessDx(0)));
383  charBuffer += *b4;
384  }
385 
386  }
387 
389 
390  };
391 
393  template <typename Ordinal, typename FadType>
394  struct SerializationTraitsImp {
395 
396  private:
397 
399  typedef typename Sacado::ValueType<FadType>::type ValueT;
400 
402  typedef Teuchos::DefaultSerializer<Ordinal,ValueT> DS;
403 
405  typedef typename DS::DefaultSerializerType ValueSerializer;
406 
408  typedef SerializationImp<Ordinal,FadType,ValueSerializer> Imp;
409 
410  public:
411 
413  static const bool supportsDirectSerialization =
414  Imp::supportsDirectSerialization;
415 
417 
418 
420  static Ordinal fromCountToIndirectBytes(const Ordinal count,
421  const FadType buffer[]) {
422  return Imp::fromCountToIndirectBytes(
423  DS::getDefaultSerializer(), count, buffer);
424  }
425 
427  static void serialize (const Ordinal count,
428  const FadType buffer[],
429  const Ordinal bytes,
430  char charBuffer[]) {
431  Imp::serialize(
432  DS::getDefaultSerializer(), count, buffer, bytes, charBuffer);
433  }
434 
436  static Ordinal fromIndirectBytesToCount(const Ordinal bytes,
437  const char charBuffer[]) {
438  return Imp::fromIndirectBytesToCount(
439  DS::getDefaultSerializer(), bytes, charBuffer);
440  }
441 
443  static void deserialize (const Ordinal bytes,
444  const char charBuffer[],
445  const Ordinal count,
446  FadType buffer[]) {
447  Imp::deserialize(
448  DS::getDefaultSerializer(), bytes, charBuffer, count, buffer);
449  }
450 
452 
453  };
454 
456  template <typename Ordinal, typename FadType>
457  struct StaticSerializationTraitsImp {
458  typedef typename Sacado::ValueType<FadType>::type ValueT;
459  typedef Teuchos::SerializationTraits<Ordinal,ValueT> vSerT;
460  typedef Teuchos::DirectSerializationTraits<Ordinal,FadType> DSerT;
461  typedef Sacado::Fad::SerializationTraitsImp<Ordinal,FadType> STI;
462 
464  static const bool supportsDirectSerialization =
465  vSerT::supportsDirectSerialization;
466 
468 
469 
471  static Ordinal fromCountToDirectBytes(const Ordinal count) {
472  return DSerT::fromCountToDirectBytes(count);
473  }
474 
476  static char* convertToCharPtr( FadType* ptr ) {
477  return DSerT::convertToCharPtr(ptr);
478  }
479 
481  static const char* convertToCharPtr( const FadType* ptr ) {
482  return DSerT::convertToCharPtr(ptr);
483  }
484 
486  static Ordinal fromDirectBytesToCount(const Ordinal bytes) {
487  return DSerT::fromDirectBytesToCount(bytes);
488  }
489 
491  static FadType* convertFromCharPtr( char* ptr ) {
492  return DSerT::convertFromCharPtr(ptr);
493  }
494 
496  static const FadType* convertFromCharPtr( const char* ptr ) {
497  return DSerT::convertFromCharPtr(ptr);
498  }
499 
501 
503 
504 
506  static Ordinal fromCountToIndirectBytes(const Ordinal count,
507  const FadType buffer[]) {
508  if (supportsDirectSerialization)
509  return DSerT::fromCountToIndirectBytes(count, buffer);
510  else
511  return STI::fromCountToIndirectBytes(count, buffer);
512  }
513 
515  static void serialize (const Ordinal count,
516  const FadType buffer[],
517  const Ordinal bytes,
518  char charBuffer[]) {
519  if (supportsDirectSerialization)
520  return DSerT::serialize(count, buffer, bytes, charBuffer);
521  else
522  return STI::serialize(count, buffer, bytes, charBuffer);
523  }
524 
526  static Ordinal fromIndirectBytesToCount(const Ordinal bytes,
527  const char charBuffer[]) {
528  if (supportsDirectSerialization)
529  return DSerT::fromIndirectBytesToCount(bytes, charBuffer);
530  else
531  return STI::fromIndirectBytesToCount(bytes, charBuffer);
532  }
533 
535  static void deserialize (const Ordinal bytes,
536  const char charBuffer[],
537  const Ordinal count,
538  FadType buffer[]) {
539  if (supportsDirectSerialization)
540  return DSerT::deserialize(bytes, charBuffer, count, buffer);
541  else
542  return STI::deserialize(bytes, charBuffer, count, buffer);
543  }
544 
546 
547  };
548 
550  template <typename Ordinal, typename FadType, typename ValueSerializer>
551  class SerializerImp {
552 
553  private:
554 
556  typedef SerializationImp<Ordinal,FadType,ValueSerializer> Imp;
557 
559  Teuchos::RCP<const ValueSerializer> vs;
560 
562  Ordinal sz;
563 
564  public:
565 
567  typedef ValueSerializer value_serializer_type;
568 
570  static const bool supportsDirectSerialization =
571  Imp::supportsDirectSerialization;
572 
574  SerializerImp(const Teuchos::RCP<const ValueSerializer>& vs_,
575  Ordinal sz_ = 0) :
576  vs(vs_), sz(sz_) {}
577 
579  Ordinal getSerializerSize() const { return sz; }
580 
582  Teuchos::RCP<const value_serializer_type> getValueSerializer() const {
583  return vs; }
584 
586 
587 
589  Ordinal fromCountToIndirectBytes(const Ordinal count,
590  const FadType buffer[]) const {
591  return Imp::fromCountToIndirectBytes(*vs, count, buffer, sz);
592  }
593 
595  void serialize (const Ordinal count,
596  const FadType buffer[],
597  const Ordinal bytes,
598  char charBuffer[]) const {
599  Imp::serialize(*vs, count, buffer, bytes, charBuffer, sz);
600  }
601 
603  Ordinal fromIndirectBytesToCount(const Ordinal bytes,
604  const char charBuffer[]) const {
605  return Imp::fromIndirectBytesToCount(*vs, bytes, charBuffer, sz);
606  }
607 
609  void deserialize (const Ordinal bytes,
610  const char charBuffer[],
611  const Ordinal count,
612  FadType buffer[]) const {
613  return Imp::deserialize(*vs, bytes, charBuffer, count, buffer, sz);
614  }
615 
617 
618  };
619 
620  } // namespace Fad
621 
622 } // namespace Sacado
623 
624 #include "Sacado_ConfigDefs.h"
625 #if defined(HAVE_SACADO_KOKKOSCORE) && defined(HAVE_SACADO_TEUCHOSKOKKOSCOMM) && defined(HAVE_SACADO_VIEW_SPEC) && !defined(SACADO_DISABLE_FAD_VIEW_SPEC)
626 
627 #include "KokkosExp_View_Fad.hpp"
628 
629 #endif
630 
631 #endif // HAVE_SACADO_TEUCHOS
632 
633 #endif // SACADO_FAD_SCALARTRAITSIMP_HPP
static std::string eval()
Sacado::Fad::DFad< double > FadType
pow(expr1.val(), expr2.val())
expr val()
sqrt(expr.val())
int Ordinal
SimpleFad< ValueT > pow(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
Sacado::Random< double > rnd