residualLayer.js 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var convLayer_1 = require("./convLayer");
  6. function residual(x, params) {
  7. var out = convLayer_1.conv(x, params.conv1);
  8. out = convLayer_1.convNoRelu(out, params.conv2);
  9. out = tf.add(out, x);
  10. out = tf.relu(out);
  11. return out;
  12. }
  13. exports.residual = residual;
  14. function residualDown(x, params) {
  15. var out = convLayer_1.convDown(x, params.conv1);
  16. out = convLayer_1.convNoRelu(out, params.conv2);
  17. var pooled = tf.avgPool(x, 2, 2, 'valid');
  18. var zeros = tf.zeros(pooled.shape);
  19. var isPad = pooled.shape[3] !== out.shape[3];
  20. var isAdjustShape = pooled.shape[1] !== out.shape[1] || pooled.shape[2] !== out.shape[2];
  21. if (isAdjustShape) {
  22. var padShapeX = tslib_1.__spreadArrays(out.shape);
  23. padShapeX[1] = 1;
  24. var zerosW = tf.zeros(padShapeX);
  25. out = tf.concat([out, zerosW], 1);
  26. var padShapeY = tslib_1.__spreadArrays(out.shape);
  27. padShapeY[2] = 1;
  28. var zerosH = tf.zeros(padShapeY);
  29. out = tf.concat([out, zerosH], 2);
  30. }
  31. pooled = isPad ? tf.concat([pooled, zeros], 3) : pooled;
  32. out = tf.add(pooled, out);
  33. out = tf.relu(out);
  34. return out;
  35. }
  36. exports.residualDown = residualDown;
  37. //# sourceMappingURL=residualLayer.js.map