residualLayer.js 1.2 KB

123456789101112131415161718192021222324252627282930313233
  1. import { __spreadArrays } from "tslib";
  2. import * as tf from '@tensorflow/tfjs-core';
  3. import { conv, convDown, convNoRelu } from './convLayer';
  /**
   * Residual unit with an identity shortcut.
   *
   * The branch runs `conv` with `params.conv1`, then `convNoRelu` with
   * `params.conv2`. The original input `x` is added back onto the branch
   * output, and ReLU is applied to the sum.
   * tf.add needs the branch output to be shape-compatible with `x`.
   * NOTE(review): presumably conv1/conv2 preserve x's shape; confirm against convLayer.
   *
   * @param {tf.Tensor} x - input tensor (also used as the shortcut).
   * @param {{conv1: *, conv2: *}} params - parameter sets for the two convolutions.
   * @returns {tf.Tensor} relu(convNoRelu(conv(x, conv1), conv2) + x).
   */
  export function residual(x, params) {
    const branch = convNoRelu(conv(x, params.conv1), params.conv2);
    const summed = tf.add(branch, x);
    return tf.relu(summed);
  }
  /**
   * Appends `amount` zero entries to the end of `tensor` along `axis`.
   * Returns `tensor` itself (no allocation) when `amount` is zero or negative.
   *
   * @param {tf.Tensor} tensor - tensor to pad.
   * @param {number} axis - axis to extend.
   * @param {number} amount - number of zero entries to append along `axis`.
   * @returns {tf.Tensor} the padded tensor, or `tensor` unchanged.
   */
  function padEndWithZeros(tensor, axis, amount) {
    if (amount <= 0) {
      return tensor;
    }
    const padShape = [...tensor.shape];
    padShape[axis] = amount;
    return tf.concat([tensor, tf.zeros(padShape)], axis);
  }

  /**
   * Downsampling residual unit.
   *
   * Main branch: `convDown` with `params.conv1`, then `convNoRelu` with
   * `params.conv2`. Shortcut: 2x2 average pooling of `x` with stride 2 and
   * 'valid' padding. The two are zero-padded to a common shape, added, and
   * passed through ReLU.
   *
   * Tensors are handled as NHWC: axes 1 and 2 are spatial and axis 3 holds the
   * channels. Zeros are always appended at the end of an axis:
   *  - each spatial axis of the main branch is padded by exactly how much it
   *    falls short of the shortcut, independently per axis;
   *  - the shortcut's channel axis is padded by exactly how much it falls
   *    short of the main branch.
   * Any other mismatch is passed to tf.add unchanged, which then applies its own
   * broadcasting rules. An example is a main branch larger than the shortcut.
   *
   * @param {tf.Tensor} x - input tensor (NHWC, per the axis usage above).
   * @param {{conv1: *, conv2: *}} params - parameter sets for the two convolutions.
   * @returns {tf.Tensor} relu(paddedShortcut + paddedMainBranch).
   */
  export function residualDown(x, params) {
    let out = convNoRelu(convDown(x, params.conv1), params.conv2);
    let shortcut = tf.avgPool(x, 2, 2, 'valid');

    // Pad height and width separately. For non-square inputs only one axis may
    // fall short, and padding both would make the add below fail.
    out = padEndWithZeros(out, 1, shortcut.shape[1] - out.shape[1]);
    out = padEndWithZeros(out, 2, shortcut.shape[2] - out.shape[2]);

    // Widen the shortcut with zero channels up to the main branch's channel
    // count. The zeros tensor is only allocated when padding is actually needed.
    shortcut = padEndWithZeros(shortcut, 3, out.shape[3] - shortcut.shape[3]);

    return tf.relu(tf.add(shortcut, out));
  }
  33. //# sourceMappingURL=residualLayer.js.map