RNet.js 894 B

12345678910111213141516171819
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
  3. import { prelu } from './prelu';
  4. import { sharedLayer } from './sharedLayers';
  /**
   * Forward pass of the R-Net stage. The name suggests the MTCNN "refine"
   * network; confirm against the model definition.
   *
   * Pipeline: sharedLayer -> flatten to a rank-2 tensor -> fc1 -> PReLU, then
   * two heads that share that hidden activation:
   *   - fc2_1: logits that are softmaxed along axis 1; column 1 becomes `scores`
   *   - fc2_2: returned unchanged as `regions`
   *
   * @param {tf.Tensor} x - Input batch, handed directly to `sharedLayer`. Its
   *   expected shape is defined there and is not visible here.
   * @param {object} params - Network parameters. This function reads `fc1`
   *   (including `fc1.weights.shape[0]`), `prelu4_alpha`, `fc2_1` and `fc2_2`,
   *   and also passes the whole object to `sharedLayer`.
   * @returns {{scores: tf.Tensor, regions: tf.Tensor}} `scores` is the
   *   axis-1 unstack at index 1 of the softmax output (presumably the positive
   *   class probability; TODO confirm). `regions` is the raw fc2_2 output.
   *   Every other tensor allocated here is released by `tf.tidy`.
   */
  export function RNet(x, params) {
    return tf.tidy(() => {
      const features = sharedLayer(x, params);

      // Collapse all non-batch dims into one axis whose width is
      // fc1.weights.shape[0], so the result can be fed to fc1.
      const batchSize = features.shape[0];
      const flatWidth = params.fc1.weights.shape[0];
      const flattened = tf.reshape(features, [batchSize, flatWidth]);

      const hidden = prelu(fullyConnectedLayer(flattened, params.fc1), params.prelu4_alpha);

      // Classification head. The per-row max is subtracted before softmax as a
      // numerical-stability shift. tf.softmax may already do this internally;
      // it is kept so the numerics stay exactly as before.
      const logits = fullyConnectedLayer(hidden, params.fc2_1);
      const rowMax = tf.expandDims(tf.max(logits, 1), 1);
      const probabilities = tf.softmax(tf.sub(logits, rowMax), 1);

      // Regression head: reuses the same hidden activation.
      const regions = fullyConnectedLayer(hidden, params.fc2_2);

      // Keep only the slice at index 1 along axis 1.
      const [, scores] = tf.unstack(probabilities, 1);

      return { scores, regions };
    });
  }
  19. //# sourceMappingURL=RNet.js.map