import { Sources, SupportedLlm } from '../../models/User';
import { logDebug } from '../utils';

/**
 * Payload describing a single chat query.
 *
 * NOTE(review): the keys look like they mirror a server-side schema (mostly
 * snake_case, but `isNewTopic` is camelCase). Renaming any of them likely breaks
 * the protocol — confirm with the backend before changing.
 */
export interface StreamMessageBody {
  // --- identifiers ---
  conversation_id: string;
  message_id: string;
  row_id: string;
  bot_id?: string;

  // --- the query itself ---
  query: string;
  isNewTopic: boolean;

  // --- files in scope for this query ---
  all_contextual_files: Array<string>;
  uploaded_conversation_file_ids?: Array<string>;

  // --- options ---
  sources?: Sources;
  /** `null` presumably means "no explicit model preference" — confirm with callers. */
  llm_preference: SupportedLlm | null;
  /** Client timezone offset; units not established here (minutes?) — TODO confirm. */
  tz_offset: number;
  deepresearch: boolean;
}

/**
 * Optional lifecycle callbacks for a streaming request.
 *
 * @typeParam T - Shape of each decoded incoming message.
 */
interface StreamRequestEvents<T> {
  /** Called when the underlying connection closes. */
  onClose?: () => void;
  /** Called with a decoded incoming message. */
  onMessage?: (message: T) => void;
  /** Called with whatever went wrong: a thrown value or an error event. */
  onError?: (error?: unknown) => void;
}

/**
 * Wraps a single `WebSocket` connection to the `/_chat_ws` endpoint of `SEARCH_URL`.
 *
 * `SEARCH_URL` is not defined in this file (presumably a build-time global). The
 * socket is opened as soon as the handler is constructed. Outgoing messages are
 * JSON-serialized and held until the connection is open. Incoming frames are
 * JSON-parsed and passed to `onMessage`.
 *
 * @typeParam T - Expected shape of incoming messages. The parsed JSON is cast to
 *   `T` without any runtime validation.
 */
export class QAWebSocketRequestHandler<T> {
  private ws!: WebSocket;
  private readonly onMessage?: (message: T) => void;
  private readonly onClose?: () => void;
  private readonly onError?: (error?: unknown) => void;
  /**
   * Settles exactly once. It resolves on the socket's `open` event, or rejects if
   * the socket reports `error` or `close` before it ever opened. Later reject
   * calls are no-ops because the promise is already settled.
   */
  private readonly isOpenPromise: Promise<void>;
  private resolveOpenPromise!: () => void;
  private rejectOpenPromise!: (reason: Error) => void;

  public constructor({ onMessage, onClose, onError }: StreamRequestEvents<T>) {
    this.onMessage = onMessage;
    this.onClose = onClose;
    this.onError = onError;

    this.isOpenPromise = new Promise<void>((resolve, reject) => {
      this.resolveOpenPromise = resolve;
      this.rejectOpenPromise = reject;
    });
    // The rejection is observed by `sendMessage` via `await`. This no-op handler
    // only stops it from being reported as unhandled when nothing is ever sent.
    this.isOpenPromise.catch(() => undefined);

    this.initWebSocket();
  }

  /**
   * Closes the underlying socket. If it had not opened yet, pending
   * `sendMessage` calls reject instead of waiting forever.
   */
  public close(): void {
    this.ws.close();
  }

  /**
   * Serializes `message` as JSON and sends it once the connection is open.
   *
   * @throws Error if the connection failed or was closed before opening, or if the
   *   socket is no longer open when the message would be sent. Without this check
   *   the browser would discard the data silently.
   */
  public async sendMessage(message: unknown): Promise<void> {
    // Wait for the WebSocket to be open (rejects if it never opens).
    await this.isOpenPromise;
    if (this.ws.readyState !== WebSocket.OPEN) {
      throw new Error('WebSocket is not open; message was not sent');
    }
    this.ws.send(JSON.stringify(message));
  }

  /** Creates the socket and wires its events to the handler's callbacks. */
  private initWebSocket(): void {
    const strippedUrl = SEARCH_URL.replace(/^https?:\/\//, '');
    const wsProtocol = SEARCH_URL.startsWith('https') ? 'wss' : 'ws';
    this.ws = new WebSocket(`${wsProtocol}://${strippedUrl}/_chat_ws`);

    this.ws.addEventListener('open', () => {
      logDebug('WebSocket connection established');
      this.resolveOpenPromise();
    });

    this.ws.addEventListener('message', (event) => {
      try {
        // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
        const message = JSON.parse(event.data as string) as T;
        this.onMessage?.(message);
      } catch (error) {
        // Reports both malformed JSON and exceptions thrown by onMessage.
        this.onError?.(error);
      }
    });

    this.ws.addEventListener('close', () => {
      // Has no effect if the socket had already opened.
      this.rejectOpenPromise(new Error('WebSocket closed before the connection was established'));
      this.onClose?.();
    });

    this.ws.addEventListener('error', (e) => {
      // Has no effect if the socket had already opened.
      this.rejectOpenPromise(new Error('WebSocket error before the connection was established'));
      this.onError?.(e);
    });
  }
}
