mobileNetV1.js 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  4. var pointwiseConvLayer_1 = require("./pointwiseConvLayer");
  5. var epsilon = 0.0010000000474974513;
  6. function depthwiseConvLayer(x, params, strides) {
  7. return tf.tidy(function () {
  8. var out = tf.depthwiseConv2d(x, params.filters, strides, 'same');
  9. out = tf.batchNorm(out, params.batch_norm_mean, params.batch_norm_variance, params.batch_norm_offset, params.batch_norm_scale, epsilon);
  10. return tf.clipByValue(out, 0, 6);
  11. });
  12. }
  13. function getStridesForLayerIdx(layerIdx) {
  14. return [2, 4, 6, 12].some(function (idx) { return idx === layerIdx; }) ? [2, 2] : [1, 1];
  15. }
  16. function mobileNetV1(x, params) {
  17. return tf.tidy(function () {
  18. var conv11 = null;
  19. var out = pointwiseConvLayer_1.pointwiseConvLayer(x, params.conv_0, [2, 2]);
  20. var convPairParams = [
  21. params.conv_1,
  22. params.conv_2,
  23. params.conv_3,
  24. params.conv_4,
  25. params.conv_5,
  26. params.conv_6,
  27. params.conv_7,
  28. params.conv_8,
  29. params.conv_9,
  30. params.conv_10,
  31. params.conv_11,
  32. params.conv_12,
  33. params.conv_13
  34. ];
  35. convPairParams.forEach(function (param, i) {
  36. var layerIdx = i + 1;
  37. var depthwiseConvStrides = getStridesForLayerIdx(layerIdx);
  38. out = depthwiseConvLayer(out, param.depthwise_conv, depthwiseConvStrides);
  39. out = pointwiseConvLayer_1.pointwiseConvLayer(out, param.pointwise_conv, [1, 1]);
  40. if (layerIdx === 11) {
  41. conv11 = out;
  42. }
  43. });
  44. if (conv11 === null) {
  45. throw new Error('mobileNetV1 - output of conv layer 11 is null');
  46. }
  47. return {
  48. out: out,
  49. conv11: conv11
  50. };
  51. });
  52. }
  53. exports.mobileNetV1 = mobileNetV1;
  54. //# sourceMappingURL=mobileNetV1.js.map