import { Middleware, MiddlewareResult } from './types';
import { AnyClass, ClassType } from '../types';
import { DynamicMiddlewareDispatcher } from './dynamic-middleware-dispatcher';

/**
 * Argument tuple for a prototype-keyed middleware call: the target instance
 * comes first, followed by the extra arguments `A`.
 *
 * @typeParam A - Tuple of extra arguments that follow the instance.
 * @typeParam T - Type of the instance passed as the first argument.
 */
export type ProtoMiddlewareArgs<A extends any[] = [], T = any> = [
  inst: T,
  ...args: A,
];
/**
 * Function shape that takes `ProtoMiddlewareArgs<A, T>` (instance first, then
 * the extra arguments) and returns a value of the instance type `T`.
 *
 * @typeParam A - Tuple of extra arguments that follow the instance.
 * @typeParam T - Type of the instance that is received and returned.
 */
export type ProtoMiddlewareFunc<A extends any[] = [], T = any> = (
  ...args: ProtoMiddlewareArgs<A, T>
) => T;

/**
 * Middleware dispatcher that selects middlewares from the class hierarchy of
 * the instance passed as the first dispatch argument.
 *
 * Middlewares are registered per class, either in a "prior" list or in the
 * normal list. For a given instance, `buildMiddlewares` assembles, in order:
 *   1. prior middlewares, base class first, newest registration first within
 *      each class;
 *   2. normal middlewares, most-derived class first, in registration order;
 *   3. a terminal middleware that returns the instance itself.
 *
 * @typeParam A - Tuple of extra arguments that follow the instance.
 */
export class ProtoMiddlewareDispatcher<
  A extends any[] = [],
> extends DynamicMiddlewareDispatcher<ProtoMiddlewareFunc<A>> {
  /** Normal middlewares, keyed by the class they were registered for. */
  private middlewareProtoMap = new Map<
    AnyClass,
    Middleware<ProtoMiddlewareFunc<A>>[]
  >();
  /** Prior middlewares, keyed by the class they were registered for. */
  private middlewareProtoMapPrior = new Map<
    AnyClass,
    Middleware<ProtoMiddlewareFunc<A>>[]
  >();

  /**
   * Registers a middleware for `cls`.
   *
   * @param cls - Class whose instances (including subclass instances) the
   *   middleware applies to.
   * @param mw - Middleware to register.
   * @param prior - When `true`, the middleware goes into the prior list, which
   *   `buildMiddlewares` places ahead of all normal middlewares.
   * @returns This dispatcher, for chaining.
   */
  middleware<T>(
    cls: ClassType<T>,
    mw: Middleware<ProtoMiddlewareFunc<A, T>>,
    prior = false,
  ) {
    const map = prior ? this.middlewareProtoMapPrior : this.middlewareProtoMap;
    const mws = map.get(cls) ?? [];
    mws.push(mw);
    map.set(cls, mws);
    return this;
  }

  /**
   * Removes a middleware previously registered for `cls`, from both the normal
   * and the prior list. At most one occurrence is removed from each list.
   *
   * @param cls - Class the middleware was registered for.
   * @param mw - Middleware to remove (matched by identity).
   * @returns This dispatcher, for chaining.
   */
  removeMiddleware<T>(
    cls: ClassType<T>,
    mw: Middleware<ProtoMiddlewareFunc<A, T>>,
  ) {
    for (const map of [this.middlewareProtoMap, this.middlewareProtoMapPrior]) {
      const mws = map.get(cls);
      if (!mws) continue;

      const index = mws.indexOf(mw);
      if (index >= 0) {
        mws.splice(index, 1);
      }
      // Drop the entry once it is empty so the map no longer holds the class
      // as a key after its last middleware is gone.
      if (mws.length === 0) {
        map.delete(cls);
      }
    }
    return this;
  }

  /**
   * Typed entry point: forwards the arguments unchanged to the base
   * dispatcher's `dispatch`.
   *
   * @param args - The instance followed by the extra arguments `A`.
   */
  async dispatch<T>(
    ...args: ProtoMiddlewareArgs<A, T>
  ): MiddlewareResult<ProtoMiddlewareFunc<A, T>> {
    return super.dispatch(...args);
  }

  /**
   * Builds the ordered middleware list for the instance in `args[0]`.
   *
   * @param args - The instance followed by the extra arguments; only the
   *   instance is inspected.
   * @returns An empty list when there is no first argument or it is not a
   *   non-null object (no terminal middleware is added in that case);
   *   otherwise prior middlewares (Base → Sub), then normal middlewares
   *   (Sub → Base), then a terminal middleware returning the instance.
   */
  async buildMiddlewares(...args: ProtoMiddlewareArgs<A>) {
    // Only the instance (first argument) is needed to select middlewares.
    if (args.length === 0) return [];

    const inst = args[0];
    if (!inst || typeof inst !== 'object') return [];

    // 1. Collect the class chain by walking the instance's actual prototype
    //    chain. Reading `inst.constructor` directly would be fooled by an own
    //    or overridden `constructor` property (and could throw when that value
    //    has no `prototype`); a prototype chain is also guaranteed to end.
    const chain: AnyClass[] = [];
    let proto: any = Object.getPrototypeOf(inst);
    while (proto && proto !== Object.prototype) {
      // Only a prototype's own `constructor` identifies the class it belongs
      // to; an inherited one will be seen further up the chain.
      const ctor = Object.prototype.hasOwnProperty.call(proto, 'constructor')
        ? proto.constructor
        : undefined;
      if (typeof ctor === 'function' && ctor !== Object) {
        chain.push(ctor);
      }
      proto = Object.getPrototypeOf(proto);
    }

    // Collected Sub → Base; flip to Base → Sub.
    chain.reverse();

    const result: Middleware<ProtoMiddlewareFunc<A>>[] = [];

    // 2. Prior middlewares: Base → Sub, newest registration first per class.
    for (const cls of chain) {
      const mws = this.middlewareProtoMapPrior.get(cls);
      if (mws) {
        result.push(...[...mws].reverse());
      }
    }

    // 3. Normal middlewares: Sub → Base, in registration order per class.
    for (let i = chain.length - 1; i >= 0; i--) {
      const cls = chain[i];
      const mws = this.middlewareProtoMap.get(cls);
      if (mws) {
        result.push(...mws);
      }
    }

    // Terminal middleware: hands back the instance it receives.
    result.push((target) => {
      return target;
    });

    return result;
  }
}
