boxPredictionLayer.js 604 B

1234567891011121314
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { convLayer } from '../common';
  /**
   * Applies the box-encoding and class-prediction heads to a feature map and
   * reshapes each head's output so that dimension 0 is the batch.
   *
   * Both heads are computed with `convLayer` inside `tf.tidy`. `tf.tidy`
   * disposes tensors allocated in its callback, except the ones it returns.
   * So only the two reshaped outputs outlive this call.
   *
   * @param {tf.Tensor} x - Input feature map. Dimension 0 is read as the batch
   *   size. NOTE(review): layout is presumably whatever `convLayer` expects
   *   (likely NHWC) — confirm.
   * @param {{box_encoding_predictor: *, class_predictor: *}} params - Per-head
   *   weights, each passed unchanged to `convLayer`.
   * @returns {{boxPredictionEncoding: tf.Tensor, classPrediction: tf.Tensor}}
   *   `boxPredictionEncoding` is reshaped to [batchSize, -1, 1, 4].
   *   `classPrediction` is reshaped to [batchSize, -1, 3].
   */
  export function boxPredictionLayer(x, params) {
    return tf.tidy(() => {
      const [batchSize] = x.shape;

      // Flatten all spatial positions into rows of 4 box-encoding values.
      const boxPredictionEncoding = tf.reshape(
        convLayer(x, params.box_encoding_predictor),
        [batchSize, -1, 1, 4],
      );

      // NOTE(review): 3 is presumably the number of class scores per anchor
      // (e.g. background + classes) — confirm against the model weights.
      const classPrediction = tf.reshape(
        convLayer(x, params.class_predictor),
        [batchSize, -1, 3],
      );

      return { boxPredictionEncoding, classPrediction };
    });
  }
  14. //# sourceMappingURL=boxPredictionLayer.js.map