import * as tf from '@tensorflow/tfjs';

/**
* Custom layer that computes a tensor whose values decrease along the last axis (from one column to the next).
*
* Code for serialization heavily adapted from Shanqing Cai: https://gist.github.com/caisq/33ed021e0c7b9d0e728cb1dce399527d
* and from here: https://github.com/tensorflow/tfjs-examples/blob/master/custom-layer/custom_layer.js
*
* 2024-01-26, Adam Muschielok, Rodenstock GmbH
* 2024-01-30, Killian Perani, Rodenstock GmbH
*/
export class DecreasingDenseLayer extends tf.layers.Layer {
  /**
   * Create a layer whose output values decrease along the last axis.
   *
   * @param {Object} config Layer configuration. It is forwarded to
   *   `tf.layers.Layer` and additionally read for the fields below.
   * @param {number} config.numOutputs Number of output units; must be a
   *   positive integer.
   * @param {boolean} [config.expOutputs=false] Whether the output is
   *   exponentiated (true) or returned in the log domain used internally
   *   (false).
   * @throws {Error} If `config.numOutputs` is not a positive integer.
   */
  constructor(config) {
    super(config);
    const { numOutputs, expOutputs = false } = config;
    // Fail early with a clear message; otherwise build() would try to create
    // a weight with a negative or NaN dimension.
    if (!Number.isInteger(numOutputs) || numOutputs < 1) {
      throw new Error(
        `DecreasingDenseLayer: numOutputs must be a positive integer, got ${numOutputs}`,
      );
    }
    this.numOutputs = numOutputs;
    this.expOutputs = expOutputs;
  }

  /**
   * Create the layer weights. build() is called when the layer is connected
   * to an upstream layer for the first time.
   *
   * Two dense projections share the input:
   * - `kernel_0` [inputDims, 1] and `bias_0` [1] give the first (largest)
   *   output in the log domain;
   * - `kernel_1` [inputDims, numOutputs - 1] and `bias_1` [numOutputs - 1]
   *   give the pre-activations of the decrements between neighbouring
   *   outputs.
   *
   * Keep the weight names, shapes and creation order stable so previously
   * saved weights remain loadable.
   *
   * @param {Array} inputShape Input shape such as [null, inputDims], or a
   *   list of shapes of which only the first is used (mirroring call()).
   */
  build(inputShape) {
    const shape = unwrapSingleShape(inputShape);
    const inputDims = shape[shape.length - 1];

    this.dense_logmax_w = this.addWeight(
      'kernel_0',
      [inputDims, 1],
      'float32',
      tf.initializers.glorotNormal(),
    );
    this.dense_logmax_b = this.addWeight(
      'bias_0',
      [1],
      'float32',
      tf.initializers.zeros(),
    );

    this.dense_decr_w = this.addWeight(
      'kernel_1',
      [inputDims, this.numOutputs - 1],
      'float32',
      tf.initializers.glorotNormal(),
    );
    this.dense_decr_b = this.addWeight(
      'bias_1',
      [this.numOutputs - 1],
      'float32',
      tf.initializers.zeros(),
    );
  }

  /**
   * Output shape for symbolic (model-building) use: the input shape with its
   * last dimension replaced by `numOutputs`.
   *
   * Without this override the base class reports the input shape unchanged,
   * so layers stacked after this one would be built for the wrong number of
   * features.
   *
   * @param {Array} inputShape Input shape, or a list of shapes of which only
   *   the first is used.
   * @returns {Array<number|null>} The output shape.
   */
  computeOutputShape(inputShape) {
    const shape = unwrapSingleShape(inputShape);
    return [...shape.slice(0, -1), this.numOutputs];
  }

  /**
   * Compute the decreasing outputs.
   *
   * With x the input, the result along the last axis is
   *   out[0] = x · kernel_0 + bias_0
   *   out[k] = out[k - 1] - softplus(x · kernel_1[:, k - 1] + bias_1[k - 1])
   * for k = 1 … numOutputs - 1. Softplus is positive (up to floating-point
   * underflow), so every step subtracts a positive amount. If `expOutputs` is
   * set, exp() is applied element-wise, giving positive values that also
   * decrease.
   *
   * All intermediate tensors are released by tf.tidy(); only the returned
   * tensor survives.
   *
   * @param {tf.Tensor|tf.Tensor[]} inputs Input tensor, or a list of which
   *   only the first is used. Presumably of shape [batch, inputDims] — TODO
   *   confirm whether higher-rank inputs are intended.
   * @param {Object} kwargs Only passed through to the call hook.
   * @returns {tf.Tensor} Tensor whose last dimension is `numOutputs`.
   */
  call(inputs, kwargs) {
    return tf.tidy(() => {
      const input = Array.isArray(inputs) ? inputs[0] : inputs;
      this.invokeCallHook(inputs, kwargs);

      // Log of the largest value, and the (negative) log-decrements.
      const logMax = tf.add(
        tf.matMul(input, this.dense_logmax_w.read()),
        this.dense_logmax_b.read(),
      );
      const logDecrements = tf.neg(
        tf.softplus(
          tf.add(
            tf.matMul(input, this.dense_decr_w.read()),
            this.dense_decr_b.read(),
          ),
        ),
      );

      // Concatenate, then take a cumulative sum along the last axis to turn
      // the decrements into absolute log-values.
      const lastAxis = logMax.shape.length - 1;
      const logValues = tf.cumsum(
        tf.concat([logMax, logDecrements], lastAxis),
        lastAxis,
      );

      return this.expOutputs ? tf.exp(logValues) : logValues;
    });
  }

  /**
   * Build the JSON-serializable configuration used when saving and loading
   * the layer: the base layer config plus `numOutputs` and `expOutputs`.
   *
   * @returns {Object} The layer configuration.
   */
  getConfig() {
    const config = super.getConfig();
    Object.assign(config, {
      numOutputs: this.numOutputs,
      expOutputs: this.expOutputs,
    });
    return config;
  }

  /**
   * Class name used by the serialization registration step.
   *
   * @returns {string} The registered class name.
   */
  static get className() {
    return 'Custom>DecreasingDenseLayer';
  }
}

/**
 * Return a single shape when a layer may receive either one shape or a list
 * of shapes. In the latter case the first shape is used, mirroring how
 * call() picks the first input tensor.
 *
 * @param {Array} inputShape A shape such as [null, 8], or a list of shapes.
 * @returns {Array<number|null>} A single shape.
 */
function unwrapSingleShape(inputShape) {
  return Array.isArray(inputShape[0]) ? inputShape[0] : inputShape;
}