predictionLayer.js 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { boxPredictionLayer } from './boxPredictionLayer';
  3. import { pointwiseConvLayer } from './pointwiseConvLayer';
  4. export function predictionLayer(x, conv11, params) {
  5. return tf.tidy(function () {
  6. var conv0 = pointwiseConvLayer(x, params.conv_0, [1, 1]);
  7. var conv1 = pointwiseConvLayer(conv0, params.conv_1, [2, 2]);
  8. var conv2 = pointwiseConvLayer(conv1, params.conv_2, [1, 1]);
  9. var conv3 = pointwiseConvLayer(conv2, params.conv_3, [2, 2]);
  10. var conv4 = pointwiseConvLayer(conv3, params.conv_4, [1, 1]);
  11. var conv5 = pointwiseConvLayer(conv4, params.conv_5, [2, 2]);
  12. var conv6 = pointwiseConvLayer(conv5, params.conv_6, [1, 1]);
  13. var conv7 = pointwiseConvLayer(conv6, params.conv_7, [2, 2]);
  14. var boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0);
  15. var boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1);
  16. var boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2);
  17. var boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3);
  18. var boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4);
  19. var boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5);
  20. var boxPredictions = tf.concat([
  21. boxPrediction0.boxPredictionEncoding,
  22. boxPrediction1.boxPredictionEncoding,
  23. boxPrediction2.boxPredictionEncoding,
  24. boxPrediction3.boxPredictionEncoding,
  25. boxPrediction4.boxPredictionEncoding,
  26. boxPrediction5.boxPredictionEncoding
  27. ], 1);
  28. var classPredictions = tf.concat([
  29. boxPrediction0.classPrediction,
  30. boxPrediction1.classPrediction,
  31. boxPrediction2.classPrediction,
  32. boxPrediction3.classPrediction,
  33. boxPrediction4.classPrediction,
  34. boxPrediction5.classPrediction
  35. ], 1);
  36. return {
  37. boxPredictions: boxPredictions,
  38. classPredictions: classPredictions
  39. };
  40. });
  41. }
  42. //# sourceMappingURL=predictionLayer.js.map