mobileNetV1.js 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { pointwiseConvLayer } from './pointwiseConvLayer';
  // Variance epsilon passed as the last argument to tf.batchNorm in
  // depthwiseConvLayer. The literal is 0.001 rounded to float32 precision.
  const epsilon = 0.0010000000474974513;
  /**
   * Runs one depthwise convolution stage: depthwise conv with 'same' padding,
   * then batch normalization, then a clamp of the activations to [0, 6].
   *
   * The work happens inside tf.tidy, so only the returned tensor is handed
   * back to the caller.
   *
   * @param x - Input passed to tf.depthwiseConv2d.
   * @param params - Object providing `filters`, `batch_norm_mean`,
   *   `batch_norm_variance`, `batch_norm_offset` and `batch_norm_scale`.
   * @param strides - Strides forwarded to tf.depthwiseConv2d.
   * @returns The clamped result of the stage.
   */
  function depthwiseConvLayer(x, params, strides) {
    return tf.tidy(() => {
      const convolved = tf.depthwiseConv2d(x, params.filters, strides, 'same');
      const normalized = tf.batchNorm(
        convolved,
        params.batch_norm_mean,
        params.batch_norm_variance,
        params.batch_norm_offset,
        params.batch_norm_scale,
        epsilon
      );
      // Clamp to the [0, 6] range (ReLU6-style activation).
      return tf.clipByValue(normalized, 0, 6);
    });
  }
  /**
   * Chooses the depthwise convolution strides for a given layer index.
   *
   * Layers 2, 4, 6 and 12 use strides [2, 2]; every other value (including
   * non-numeric input, which never strictly equals one of those indices)
   * gets [1, 1]. A new array is returned on every call.
   *
   * @param layerIdx - Index of the conv pair, as counted by the caller.
   * @returns {number[]} Either [2, 2] or [1, 1].
   */
  function getStridesForLayerIdx(layerIdx) {
    const stridedLayers = [2, 4, 6, 12];
    // indexOf compares with strict equality, same as an `===` predicate.
    if (stridedLayers.indexOf(layerIdx) !== -1) {
      return [2, 2];
    }
    return [1, 1];
  }
  /**
   * Runs the MobileNetV1 layer stack on `x`.
   *
   * The input first goes through a pointwise conv layer (`params.conv_0`,
   * strides [2, 2]). It is then fed through 13 depthwise + pointwise conv
   * pairs taken from `params.conv_1` .. `params.conv_13`. The depthwise
   * strides of each pair come from getStridesForLayerIdx; the pointwise
   * layer always uses strides [1, 1]. The output of pair 11 is captured
   * separately.
   *
   * Everything runs inside tf.tidy, which hands back the returned object.
   *
   * @param x - Network input passed to the first pointwise conv layer.
   * @param params - Object holding `conv_0` and `conv_1` .. `conv_13`; each
   *   of the latter provides `depthwise_conv` and `pointwise_conv`.
   * @returns {{out: *, conv11: *}} `out` is the result after the last pair,
   *   `conv11` the result after pair 11.
   * @throws {Error} If no output was recorded for pair 11.
   */
  export function mobileNetV1(x, params) {
    return tf.tidy(() => {
      let features = pointwiseConvLayer(x, params.conv_0, [2, 2]);

      // Collect conv_1 .. conv_13 up front, in order.
      const pairCount = 13;
      const pairParams = [];
      for (let n = 1; n <= pairCount; n++) {
        pairParams.push(params[`conv_${n}`]);
      }

      let conv11 = null;
      for (let layerIdx = 1; layerIdx <= pairParams.length; layerIdx++) {
        const pair = pairParams[layerIdx - 1];
        const depthwiseStrides = getStridesForLayerIdx(layerIdx);
        features = depthwiseConvLayer(features, pair.depthwise_conv, depthwiseStrides);
        features = pointwiseConvLayer(features, pair.pointwise_conv, [1, 1]);
        if (layerIdx === 11) {
          // Keep the intermediate result of pair 11 alongside the final one.
          conv11 = features;
        }
      }

      if (conv11 === null) {
        throw new Error('mobileNetV1 - output of conv layer 11 is null');
      }

      return { out: features, conv11: conv11 };
    });
  }
  51. //# sourceMappingURL=mobileNetV1.js.map