import invariant from 'invariant';

import { JwtPayload, jwtDecode } from 'jwt-decode';
import nullthrows from 'nullthrows';
import { Store } from 'redux';
import Session from '../common/utils/Session';
import { receivedChannelMessage } from './actions';

/**
 * Maintains one WebSocket connection to the `/channel` endpoint of the current
 * host and dispatches every decoded server message to the redux store via
 * `receivedChannelMessage`.
 *
 * Topic subscriptions are remembered in memory and re-sent every time a new
 * socket is opened. When the socket closes, or a connection attempt fails
 * (including a failure to obtain a JWT), a reconnect is scheduled with
 * exponential backoff.
 */
class ChannelListener {
  /** Delay before the first reconnect attempt; doubles per consecutive failure. */
  private static readonly RECONNECT_BASE_DELAY_MS = 1000;
  /** Upper bound for the reconnect delay. */
  private static readonly RECONNECT_MAX_DELAY_MS = 30000;

  private store_: Store<any> | null = null;
  private session_: Session | null = null;

  /** The currently open socket; `null` while disconnected or still connecting. */
  private ws_: WebSocket | null = null;
  /** Topics to (re)subscribe to whenever a socket is opened. */
  private subscriptions_: Set<string> = new Set();

  /** Pending reconnect timer, used to guarantee at most one scheduled retry. */
  private reconnectTimer_: ReturnType<typeof setTimeout> | null = null;
  /** Number of consecutive failed/lost connections since the last success. */
  private reconnectAttempts_ = 0;

  /**
   * Sends a chat message to `room` over the open socket.
   *
   * Best effort: the message is silently dropped when no socket is open.
   */
  async sendChatMessage(room: string, message: string): Promise<void> {
    if (!this.ws_) {
      return;
    }
    this.ws_.send(
      JSON.stringify({
        type: 'chat',
        room,
        message,
      }),
    );
  }

  /**
   * Records `topic` as subscribed and, if a socket is open, tells the server
   * right away. Otherwise the subscribe is sent once a socket opens.
   * Subscribing to an already-subscribed topic is a no-op.
   */
  async subscribeToTopic(topic: string): Promise<void> {
    if (this.subscriptions_.has(topic)) {
      return;
    }

    this.subscriptions_.add(topic);

    if (this.ws_) {
      this._sendSubscribe(this.ws_, topic);
    }
  }

  /**
   * Forgets `topic` and, if a socket is open, tells the server to stop sending
   * it. Unsubscribing from a topic that is not subscribed is a no-op.
   */
  async unsubscribeToTopic(topic: string): Promise<void> {
    if (!this.subscriptions_.has(topic)) {
      return;
    }
    this.subscriptions_.delete(topic);

    if (this.ws_) {
      this._sendUnsubscribe(this.ws_, topic);
    }
  }

  /** Sends a `subscribe` frame for `topic` on `ws`. */
  private _sendSubscribe(ws: WebSocket, topic: string): void {
    ws.send(
      JSON.stringify({
        type: 'subscribe',
        topic,
      }),
    );
  }

  /** Sends an `unsubscribe` frame for `topic` on `ws`. */
  private _sendUnsubscribe(ws: WebSocket, topic: string): void {
    ws.send(
      JSON.stringify({
        type: 'unsubscribe',
        topic,
      }),
    );
  }

  /** Logs socket errors; reconnection is driven by the close handler. */
  private _onWSError = (e: Event) => {
    console.error('websocket error', e);
  };

  /**
   * Parses an incoming frame as JSON and dispatches it to the store.
   * Frames that are not valid JSON are logged and dropped.
   */
  private _onWSMessage = (e: MessageEvent) => {
    let message: any;
    try {
      message = JSON.parse(e.data);
    } catch (err) {
      console.error('websocket received malformed message', err);
      return;
    }

    const store = this.store_;
    if (!store) {
      return;
    }
    // @ts-ignore
    store.dispatch(receivedChannelMessage(message));
  };

  /** Drops the closed socket and schedules a reconnect. */
  private _onWSClose = (e: Event) => {
    console.error('websocket close', e);
    this.ws_ = null;
    this._scheduleReconnect();
  };

  /**
   * Schedules a single reconnect attempt with exponential backoff. Does
   * nothing if one is already pending, so a failed connect (which both
   * rejects and fires `close`) results in only one retry.
   */
  private _scheduleReconnect(): void {
    if (this.reconnectTimer_ !== null) {
      return;
    }
    const delay = Math.min(
      ChannelListener.RECONNECT_BASE_DELAY_MS * 2 ** this.reconnectAttempts_,
      ChannelListener.RECONNECT_MAX_DELAY_MS,
    );
    this.reconnectAttempts_ += 1;
    this.reconnectTimer_ = setTimeout(() => {
      this.reconnectTimer_ = null;
      void this._connectWithRetry();
    }, delay);
  }

  /**
   * Runs `_connect` and never rejects: failures are logged and retried, so
   * callers can safely fire and forget.
   */
  private async _connectWithRetry(): Promise<void> {
    try {
      await this._connect();
      this.reconnectAttempts_ = 0;
    } catch (e) {
      console.error('websocket connect failed', e);
      this._scheduleReconnect();
    }
  }

  /**
   * Opens a new socket, waits for it to open, publishes it as the current
   * socket and re-sends all remembered subscriptions.
   *
   * @throws if no JWT can be obtained or the socket errors before opening.
   */
  private async _connect(): Promise<void> {
    const jwt = await this._genJWT();
    const url = this._getURL(jwt);

    const ws = new WebSocket(url);
    ws.onmessage = this._onWSMessage;
    ws.onclose = this._onWSClose;

    await new Promise<void>((resolve, reject) => {
      ws.onopen = () => {
        console.log('websocket connected');
        // Once open, errors are only logged; swap out the connect-time handler.
        ws.onerror = this._onWSError;
        resolve();
      };
      ws.onerror = (e) => {
        this._onWSError(e);
        reject(new Error('websocket failed to connect'));
      };
    });

    this.ws_ = ws;

    for (const topic of this.subscriptions_) {
      this._sendSubscribe(ws, topic);
    }
  }

  /**
   * Builds the channel URL for the current host: `wss:` on the default port
   * when the page is served over https, otherwise `ws:` on port 3002. When a
   * token is given it is appended, URL-encoded, as the `jwt` query parameter.
   */
  private _getURL(token: string | null): string {
    let protocol = 'ws:';
    let port = ':3002';
    if (window.location.protocol === 'https:') {
      protocol = 'wss:';
      port = '';
    }
    let url = protocol + '//' + window.location.hostname + port + '/channel';
    if (token) {
      url += `?jwt=${encodeURIComponent(token)}`;
    }
    return url;
  }

  /** Returns a JWT for the current session, or `null` when there is no session. */
  private async _genJWT(): Promise<string | null> {
    if (!this.session_) {
      return null;
    }
    return this._genFetchJWT();
  }

  /**
   * Returns the cached JWT when it belongs to the current user and has not
   * expired; otherwise fetches a fresh one from `/auth/jwt` and caches it.
   *
   * @throws Error when the request fails or the returned token does not
   *   belong to the current user.
   */
  private async _genFetchJWT(): Promise<string> {
    const cachedJWT = this._readCachedJWT();
    if (
      cachedJWT !== null &&
      this._isValidJWT(cachedJWT) &&
      !this._isExpiredJWT(cachedJWT)
    ) {
      return cachedJWT;
    }

    const response = await fetch('/auth/jwt', { credentials: 'same-origin' });
    if (!response.ok) {
      throw new Error(`Failed to fetch JWT: HTTP ${response.status}`);
    }
    const jwt = await response.text();
    if (!this._isValidJWT(jwt)) {
      throw new Error('Invalid JWT');
    }
    this._writeCachedJWT(jwt);
    return jwt;
  }

  /** Reads the cached JWT; returns `null` when storage is missing or inaccessible. */
  private _readCachedJWT(): string | null {
    try {
      if (typeof window !== 'undefined' && window.localStorage) {
        return window.localStorage.getItem('jwt');
      }
    } catch (e) {
      // Storage access can throw (e.g. blocked by browser settings); the cache
      // is only an optimisation, so fall through to fetching a fresh token.
      console.error('unable to read cached jwt', e);
    }
    return null;
  }

  /** Caches `jwt`; storage failures are logged and otherwise ignored. */
  private _writeCachedJWT(jwt: string): void {
    try {
      if (typeof window !== 'undefined' && window.localStorage) {
        window.localStorage.setItem('jwt', jwt);
      }
    } catch (e) {
      console.error('unable to cache jwt', e);
    }
  }

  /** Decodes the token payload, returning `null` when it cannot be decoded. */
  private _decodeJWT(token: string): JwtPayload | null {
    try {
      return jwtDecode<JwtPayload>(token) || null;
    } catch (e) {
      return false || null;
    }
  }

  /**
   * Whether the token carries an `exp` claim (seconds since the epoch) that
   * is in the past. Tokens without `exp` are treated as not expired.
   */
  private _isExpiredJWT(token: string): boolean {
    const exp = this._decodeJWT(token)?.exp;
    return typeof exp === 'number' && exp * 1000 <= Date.now();
  }

  /**
   * Whether `token` decodes successfully and its `sub` claim equals the
   * current session's user id. Does not check expiry.
   */
  private _isValidJWT(token: string | null): boolean {
    if (!token) {
      return false;
    }
    const decoded = this._decodeJWT(token);
    if (!decoded) {
      return false;
    }

    return decoded.sub === this.session_?.getUserID();
  }

  /**
   * Binds the listener to `store` and `session` and starts connecting in the
   * background. May only be called once per instance.
   *
   * @returns this instance, for chaining.
   */
  start(store: Store<any>, session: Session | null): this {
    invariant(!this.store_, 'already started');

    this.store_ = store;
    this.session_ = session ?? null;
    // Fire and forget: failures are logged and retried internally.
    void this._connectWithRetry();

    return this;
  }
}

// Module-wide singleton; stays null until startChannelListener succeeds.
let channelListener: ChannelListener | null = null;

/**
 * Returns the shared ChannelListener, creating and starting it with the given
 * store and session on the first call. Later calls return the existing
 * instance and ignore their arguments.
 */
export function startChannelListener(
  store: Store<any>,
  session: Session | null,
): ChannelListener {
  if (channelListener) {
    return channelListener;
  }

  const started = new ChannelListener().start(store, session);
  channelListener = started;
  return started;
}

/**
 * Returns the shared ChannelListener. The value is passed through
 * `nullthrows`, so the listener must already have been started.
 */
export function getChannelListener(): ChannelListener {
  const listener = nullthrows(channelListener);
  return listener;
}
