ONet.js 1.1 KB

123456789101112131415161718192021222324
  1. import * as tf from '@tensorflow/tfjs-core';
  2. import { convLayer } from '../common';
  3. import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
  4. import { prelu } from './prelu';
  5. import { sharedLayer } from './sharedLayers';
  /**
   * Runs the ONet forward pass on `x` and returns its three outputs.
   *
   * The whole computation is executed inside a single `tf.tidy` callback.
   *
   * @param {*} x - Network input; handed unchanged to `sharedLayer`.
   *   (Presumably an image batch tensor — confirm against callers.)
   * @param {Object} params - Weight container. It is passed whole to
   *   `sharedLayer`, and this function additionally reads `conv4`,
   *   `prelu4_alpha`, `fc1` (including `fc1.weights.shape[0]`),
   *   `prelu5_alpha`, `fc2_1`, `fc2_2` and `fc2_3` from it.
   * @returns {{scores: *, regions: *, points: *}} `scores` is the slice at
   *   index 1 along axis 1 of the softmax over the `fc2_1` output; `regions`
   *   and `points` are the `fc2_2` and `fc2_3` outputs respectively.
   */
  export function ONet(x, params) {
    return tf.tidy(() => {
      // Shared trunk, then 2x2 max-pool (stride 2, 'same'), conv4 and PReLU.
      const trunk = sharedLayer(x, params);
      const pooled = tf.maxPool(trunk, [2, 2], [2, 2], 'same');
      const conv4Out = convLayer(pooled, params.conv4, 'valid');
      const features = prelu(conv4Out, params.prelu4_alpha);

      // Flatten to [first dim of features, first dim of the fc1 weights].
      const leadingDim = features.shape[0];
      const fc1InputSize = params.fc1.weights.shape[0];
      const flattened = tf.reshape(features, [leadingDim, fc1InputSize]);

      // Hidden representation that feeds all three output heads.
      const hidden = prelu(
        fullyConnectedLayer(flattened, params.fc1),
        params.prelu5_alpha
      );

      // Head 1: subtract the per-row maximum (axis 1) before the softmax;
      // presumably for numerical stability.
      const logits = fullyConnectedLayer(hidden, params.fc2_1);
      const rowMax = tf.max(logits, 1);
      const shifted = tf.sub(logits, tf.expandDims(rowMax, 1));
      const prob = tf.softmax(shifted, 1);

      // Heads 2 and 3 use the same hidden representation.
      const regions = fullyConnectedLayer(hidden, params.fc2_2);
      const points = fullyConnectedLayer(hidden, params.fc2_3);

      // Keep only the slice at index 1 along axis 1 of the probabilities.
      const scores = tf.unstack(prob, 1)[1];

      return { scores, regions, points };
    });
  }
  24. //# sourceMappingURL=ONet.js.map