predictionLayer.js 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  4. var boxPredictionLayer_1 = require("./boxPredictionLayer");
  5. var pointwiseConvLayer_1 = require("./pointwiseConvLayer");
  /**
   * Builds the box and class predictions from two input feature maps.
   *
   * `x` is pushed through eight pointwise convolution layers
   * (`params.conv_0` .. `params.conv_7`) whose strides alternate between
   * [1, 1] and [2, 2]. A box predictor is then applied to `conv11`, to `x`
   * itself, and to the output of each stride-[2, 2] layer, in that order.
   * The six results are joined along axis 1.
   *
   * @param x - Input to the convolution chain and to `params.box_predictor_1`
   *     (presumably a tf.Tensor4D; confirm against the TypeScript source).
   * @param conv11 - Input to `params.box_predictor_0` only.
   * @param params - Object holding `conv_0`..`conv_7` and
   *     `box_predictor_0`..`box_predictor_5`.
   * @returns `{ boxPredictions, classPredictions }`, each the axis-1
   *     concatenation of the corresponding field of the six box predictions.
   */
  function predictionLayer(x, conv11, params) {
      // Thin wrapper so each stage reads as (input, weights, stride); a fresh
      // [stride, stride] array is still built for every call.
      function pointwise(input, weights, stride) {
          return pointwiseConvLayer_1.pointwiseConvLayer(input, weights, [stride, stride]);
      }
      // Everything runs inside tf.tidy, as in the original implementation.
      return tf.tidy(function () {
          // Each line is one stride-1 layer followed by one stride-2 layer;
          // only the stride-2 outputs are kept because only they feed a
          // box predictor.
          var reduced1 = pointwise(pointwise(x, params.conv_0, 1), params.conv_1, 2);
          var reduced3 = pointwise(pointwise(reduced1, params.conv_2, 1), params.conv_3, 2);
          var reduced5 = pointwise(pointwise(reduced3, params.conv_4, 1), params.conv_5, 2);
          var reduced7 = pointwise(pointwise(reduced5, params.conv_6, 1), params.conv_7, 2);
          // [input, predictor weights] pairs; the order fixes the order of
          // the concatenated outputs.
          var heads = [
              [conv11, params.box_predictor_0],
              [x, params.box_predictor_1],
              [reduced1, params.box_predictor_2],
              [reduced3, params.box_predictor_3],
              [reduced5, params.box_predictor_4],
              [reduced7, params.box_predictor_5]
          ].map(function (head) {
              return boxPredictionLayer_1.boxPredictionLayer(head[0], head[1]);
          });
          var encodings = heads.map(function (prediction) {
              return prediction.boxPredictionEncoding;
          });
          var boxPredictions = tf.concat(encodings, 1);
          var scores = heads.map(function (prediction) {
              return prediction.classPrediction;
          });
          var classPredictions = tf.concat(scores, 1);
          return {
              boxPredictions: boxPredictions,
              classPredictions: classPredictions
          };
      });
  }
  44. exports.predictionLayer = predictionLayer;
  45. //# sourceMappingURL=predictionLayer.js.map