import { ConflictException, Injectable, NotFoundException } from '@nestjs/common';
import { and, asc, eq, inArray, sql } from 'drizzle-orm';
import { DatabaseService } from '../../common/database/database.service';
import { tradeAdditions, tradePartials, trades, tradingAccounts } from '../../database/schema';
import type {
  CreateTradeDto,
  TradeAdditionDto,
  TradePartialDto,
  UpdateTradeDto,
} from './dto/trade.dto';
import type { CreateTradingAccountDto, UpdateTradingAccountDto } from './dto/trading-account.dto';

/**
 * Persistence layer for trading accounts and their trades (including trade
 * additions and partial closes), built on drizzle-orm.
 *
 * Every public method is scoped to an `ownerId`: accounts are filtered on
 * `tradingAccounts.ownerId`, and trades are reached through the account they
 * belong to. Updates use optimistic concurrency — the caller passes the
 * `version` it last read, the row is only updated while that version still
 * matches, and the stored version is incremented on success.
 */
@Injectable()
export class TradingRepository {
  constructor(private readonly database: DatabaseService) {}

  /** Lists the owner's trading accounts, oldest first (by `createdAt`). */
  listAccounts(ownerId: string) {
    return this.database.db
      .select()
      .from(tradingAccounts)
      .where(eq(tradingAccounts.ownerId, ownerId))
      .orderBy(asc(tradingAccounts.createdAt));
  }

  /**
   * Creates a trading account owned by `ownerId`.
   *
   * Text fields are trimmed, and blank optional fields are stored as `null`.
   *
   * @returns The inserted account row.
   */
  async createAccount(ownerId: string, input: CreateTradingAccountDto) {
    const [account] = await this.database.db
      .insert(tradingAccounts)
      .values({
        ownerId,
        name: input.name.trim(),
        // `||` rather than `??` so an empty colour is stored as null, exactly as
        // updateAccount does.
        color: input.color || null,
        exchange: input.exchange?.trim() || null,
        description: input.description?.trim() || null,
      })
      .returning();
    return account;
  }

  /**
   * Applies a partial update to one of the owner's accounts.
   *
   * Only fields present (not `undefined`) on `input` are written. The update
   * only applies when the stored version equals `input.version`, and the
   * version is then incremented.
   *
   * @returns The updated account row.
   * @throws NotFoundException if the account does not exist for this owner.
   * @throws ConflictException if the stored version differs from `input.version`.
   */
  async updateAccount(ownerId: string, id: string, input: UpdateTradingAccountDto) {
    const [account] = await this.database.db
      .update(tradingAccounts)
      .set({
        ...(input.name !== undefined ? { name: input.name.trim() } : {}),
        ...(input.color !== undefined ? { color: input.color || null } : {}),
        ...(input.exchange !== undefined ? { exchange: input.exchange.trim() || null } : {}),
        ...(input.description !== undefined
          ? { description: input.description.trim() || null }
          : {}),
        version: sql`${tradingAccounts.version} + 1`,
        updatedAt: new Date(),
      })
      .where(
        and(
          eq(tradingAccounts.id, id),
          eq(tradingAccounts.ownerId, ownerId),
          eq(tradingAccounts.version, input.version),
        ),
      )
      .returning();

    // Returning the always-throwing helper (typed Promise<never>) lets the
    // compiler narrow `account` to a defined row on the line below.
    if (!account) return this.throwNotFoundOrConflict(ownerId, id, input.version);
    return account;
  }

  /**
   * Deletes one of the owner's accounts.
   *
   * @throws NotFoundException if no account with this id belongs to the owner.
   */
  async deleteAccount(ownerId: string, id: string): Promise<void> {
    const deleted = await this.database.db
      .delete(tradingAccounts)
      .where(and(eq(tradingAccounts.id, id), eq(tradingAccounts.ownerId, ownerId)))
      .returning({ id: tradingAccounts.id });
    if (!deleted.length) throw new NotFoundException('Trading account not found.');
  }

  /**
   * Lists all trades of an owned account, ordered by `openDate` then
   * `createdAt`, each hydrated with its additions and partials.
   *
   * @throws NotFoundException if the account does not exist for this owner.
   */
  async listTrades(ownerId: string, accountId: string) {
    await this.requireAccount(ownerId, accountId);
    const rows = await this.database.db
      .select()
      .from(trades)
      .where(eq(trades.accountId, accountId))
      .orderBy(asc(trades.openDate), asc(trades.createdAt));
    return this.hydrateTrades(rows);
  }

  /**
   * Creates a trade, plus its additions and partials, in one transaction.
   *
   * @returns The hydrated trade, in the same shape as `listTrades` items.
   * @throws NotFoundException if the account does not exist for this owner.
   */
  async createTrade(ownerId: string, accountId: string, input: CreateTradeDto) {
    await this.requireAccount(ownerId, accountId);
    return this.database.db.transaction(async (tx) => {
      const [trade] = await tx
        .insert(trades)
        .values(this.tradeValues(accountId, input))
        .returning();
      await this.replaceChildren(tx, trade.id, input.adds, input.partials);
      const [hydrated] = await this.hydrateTrades([trade], tx);
      return hydrated;
    });
  }

  /**
   * Applies a partial update to an owned trade, using the optimistic `version`.
   *
   * `adds` and `partials` are each replaced wholesale when present on `input`
   * and left untouched when `undefined`.
   *
   * @returns The hydrated, updated trade.
   * @throws NotFoundException if the trade does not exist for this owner,
   *   including when it is deleted concurrently.
   * @throws ConflictException if the stored version differs from `input.version`.
   */
  async updateTrade(ownerId: string, id: string, input: UpdateTradeDto) {
    // tradePatchValues never writes accountId, so this ownership check cannot
    // be invalidated by the update itself.
    await this.requireOwnedTrade(ownerId, id);

    return this.database.db.transaction(async (tx) => {
      const [trade] = await tx
        .update(trades)
        .set({
          ...this.tradePatchValues(input),
          version: sql`${trades.version} + 1`,
          updatedAt: new Date(),
        })
        .where(and(eq(trades.id, id), eq(trades.version, input.version)))
        .returning();

      if (!trade) {
        // No row matched: either the version is stale, or the trade was deleted
        // after the ownership check above. Report the two cases distinctly.
        const [current] = await tx
          .select({ version: trades.version })
          .from(trades)
          .where(eq(trades.id, id))
          .limit(1);
        if (!current) throw new NotFoundException('Trade not found.');
        throw new ConflictException(
          `Trade was changed by another request (expected version ${input.version}, current version ${current.version}).`,
        );
      }

      if (input.adds !== undefined || input.partials !== undefined) {
        await this.replaceChildren(
          tx,
          id,
          input.adds,
          input.partials,
          input.adds !== undefined,
          input.partials !== undefined,
        );
      }

      const [hydrated] = await this.hydrateTrades([trade], tx);
      return hydrated;
    });
  }

  /**
   * Deletes an owned trade.
   *
   * @throws NotFoundException if the trade does not exist for this owner.
   */
  async deleteTrade(ownerId: string, id: string): Promise<void> {
    await this.requireOwnedTrade(ownerId, id);
    await this.database.db.delete(trades).where(eq(trades.id, id));
  }

  /**
   * Loads the additions and partials of `rows` and maps each trade to its API
   * shape, with short field names (`dir`, `sl`, `tp`, `qty`, ...).
   *
   * Children are fetched with one query per child table and then grouped by
   * trade id. Within a trade, children keep the query order (`date`, then
   * `createdAt`).
   *
   * @param rows Trade rows to hydrate. An empty array short-circuits with no queries.
   * @param database Query handle. Pass a transaction to read uncommitted writes.
   */
  private async hydrateTrades(
    rows: (typeof trades.$inferSelect)[],
    database: Pick<DatabaseService['db'], 'select'> = this.database.db,
  ) {
    if (!rows.length) return [];
    const ids = rows.map((row) => row.id);
    const additions = await database
      .select()
      .from(tradeAdditions)
      .where(inArray(tradeAdditions.tradeId, ids))
      .orderBy(asc(tradeAdditions.date), asc(tradeAdditions.createdAt));
    const partials = await database
      .select()
      .from(tradePartials)
      .where(inArray(tradePartials.tradeId, ids))
      .orderBy(asc(tradePartials.date), asc(tradePartials.createdAt));

    // Group once (O(n)) instead of filtering every child list per trade (O(n·m)).
    const additionsByTrade = this.groupBy(additions, (addition) => addition.tradeId);
    const partialsByTrade = this.groupBy(partials, (partial) => partial.tradeId);

    return rows.map((trade) => ({
      id: trade.id,
      accountId: trade.accountId,
      openDate: trade.openDate,
      closeDate: trade.closeDate,
      asset: trade.asset,
      dir: trade.direction,
      tradeType: trade.tradeType,
      result: trade.result,
      entry: trade.entry,
      exit: trade.exit,
      sl: trade.stopLoss,
      tp: trade.takeProfit,
      liqPrice: trade.liquidationPrice,
      lev: trade.leverage,
      capital: trade.capital,
      qty: trade.quantity,
      baseQty: trade.baseQuantity,
      baseEntry: trade.baseEntry,
      baseCapital: trade.baseCapital,
      pnl: trade.pnl,
      setup: trade.setup,
      emotion: trade.emotion,
      notes: trade.notes,
      version: trade.version,
      createdAt: trade.createdAt,
      updatedAt: trade.updatedAt,
      adds: (additionsByTrade.get(trade.id) ?? []).map((addition) => ({
        id: addition.id,
        date: addition.date,
        qty: addition.quantity,
        price: addition.price,
        margin: addition.margin,
      })),
      partials: (partialsByTrade.get(trade.id) ?? []).map((partial) => ({
        id: partial.id,
        date: partial.date,
        qty: partial.quantity,
        price: partial.price,
        pnl: partial.pnl,
        note: partial.note,
      })),
    }));
  }

  /** Buckets `items` by `keyOf`, preserving each item's relative order within its bucket. */
  private groupBy<T, K>(items: readonly T[], keyOf: (item: T) => K): Map<K, T[]> {
    const grouped = new Map<K, T[]>();
    for (const item of items) {
      const key = keyOf(item);
      const bucket = grouped.get(key);
      if (bucket) bucket.push(item);
      else grouped.set(key, [item]);
    }
    return grouped;
  }

  /**
   * Maps a create DTO to a `trades` insert row.
   *
   * Omitted optional fields become `null`, and text fields are trimmed, with
   * blanks stored as `null`.
   */
  private tradeValues(accountId: string, input: CreateTradeDto) {
    return {
      accountId,
      openDate: input.openDate,
      closeDate: input.closeDate ?? null,
      asset: input.asset,
      direction: input.dir,
      tradeType: input.tradeType,
      result: input.result,
      entry: input.entry ?? null,
      exit: input.exit ?? null,
      stopLoss: input.sl ?? null,
      takeProfit: input.tp ?? null,
      liquidationPrice: input.liqPrice ?? null,
      leverage: input.lev ?? null,
      capital: input.capital ?? null,
      quantity: input.qty ?? null,
      baseQuantity: input.baseQty ?? null,
      baseEntry: input.baseEntry ?? null,
      baseCapital: input.baseCapital ?? null,
      pnl: input.pnl ?? null,
      setup: input.setup?.trim() || null,
      emotion: input.emotion?.trim() || null,
      notes: input.notes?.trim() || null,
    };
  }

  /**
   * Maps an update DTO to a partial `trades` update.
   *
   * Fields left `undefined` on `input` are omitted so they keep their stored
   * value. `accountId` is never included.
   */
  private tradePatchValues(input: UpdateTradeDto) {
    return {
      ...(input.openDate !== undefined ? { openDate: input.openDate } : {}),
      ...(input.closeDate !== undefined ? { closeDate: input.closeDate } : {}),
      ...(input.asset !== undefined ? { asset: input.asset } : {}),
      ...(input.dir !== undefined ? { direction: input.dir } : {}),
      ...(input.tradeType !== undefined ? { tradeType: input.tradeType } : {}),
      ...(input.result !== undefined ? { result: input.result } : {}),
      ...(input.entry !== undefined ? { entry: input.entry } : {}),
      ...(input.exit !== undefined ? { exit: input.exit } : {}),
      ...(input.sl !== undefined ? { stopLoss: input.sl } : {}),
      ...(input.tp !== undefined ? { takeProfit: input.tp } : {}),
      ...(input.liqPrice !== undefined ? { liquidationPrice: input.liqPrice } : {}),
      ...(input.lev !== undefined ? { leverage: input.lev } : {}),
      ...(input.capital !== undefined ? { capital: input.capital } : {}),
      ...(input.qty !== undefined ? { quantity: input.qty } : {}),
      ...(input.baseQty !== undefined ? { baseQuantity: input.baseQty } : {}),
      ...(input.baseEntry !== undefined ? { baseEntry: input.baseEntry } : {}),
      ...(input.baseCapital !== undefined ? { baseCapital: input.baseCapital } : {}),
      ...(input.pnl !== undefined ? { pnl: input.pnl } : {}),
      ...(input.setup !== undefined ? { setup: input.setup.trim() || null } : {}),
      ...(input.emotion !== undefined ? { emotion: input.emotion.trim() || null } : {}),
      ...(input.notes !== undefined ? { notes: input.notes.trim() || null } : {}),
    };
  }

  /**
   * Replaces a trade's additions and/or partials on the transaction `tx`, by
   * deleting the existing rows and inserting the supplied ones.
   *
   * @param replaceAdditions When false, existing additions are left untouched.
   * @param replacePartials When false, existing partials are left untouched.
   */
  private async replaceChildren(
    tx: Parameters<Parameters<DatabaseService['db']['transaction']>[0]>[0],
    tradeId: string,
    additions: TradeAdditionDto[] | undefined,
    partials: TradePartialDto[] | undefined,
    replaceAdditions = true,
    replacePartials = true,
  ): Promise<void> {
    if (replaceAdditions) {
      await tx.delete(tradeAdditions).where(eq(tradeAdditions.tradeId, tradeId));
      if (additions?.length) {
        await tx.insert(tradeAdditions).values(
          additions.map((addition) => ({
            tradeId,
            date: addition.date,
            quantity: addition.quantity ?? null,
            price: addition.price ?? null,
            margin: addition.margin ?? null,
          })),
        );
      }
    }

    if (replacePartials) {
      await tx.delete(tradePartials).where(eq(tradePartials.tradeId, tradeId));
      if (partials?.length) {
        await tx.insert(tradePartials).values(
          partials.map((partial) => ({
            tradeId,
            date: partial.date,
            quantity: partial.quantity ?? null,
            price: partial.price ?? null,
            pnl: partial.pnl ?? null,
            note: partial.note?.trim() || null,
          })),
        );
      }
    }
  }

  /**
   * Loads an account by id, scoped to the owner.
   *
   * @throws NotFoundException if it does not exist or belongs to another owner.
   */
  private async requireAccount(ownerId: string, id: string) {
    const [account] = await this.database.db
      .select()
      .from(tradingAccounts)
      .where(and(eq(tradingAccounts.id, id), eq(tradingAccounts.ownerId, ownerId)))
      .limit(1);
    if (!account) throw new NotFoundException('Trading account not found.');
    return account;
  }

  /**
   * Loads a trade by id, verifying ownership through a join on its account.
   *
   * @throws NotFoundException if the trade does not exist or its account
   *   belongs to another owner.
   */
  private async requireOwnedTrade(ownerId: string, id: string) {
    const [trade] = await this.database.db
      .select({ trade: trades })
      .from(trades)
      .innerJoin(tradingAccounts, eq(tradingAccounts.id, trades.accountId))
      .where(and(eq(trades.id, id), eq(tradingAccounts.ownerId, ownerId)))
      .limit(1);
    if (!trade) throw new NotFoundException('Trade not found.');
    return trade.trade;
  }

  /**
   * Called after a versioned account update matched no row. Always throws.
   *
   * @throws NotFoundException if the account does not exist for this owner.
   * @throws ConflictException otherwise, reporting the expected and current versions.
   */
  private async throwNotFoundOrConflict(
    ownerId: string,
    id: string,
    version: number,
  ): Promise<never> {
    const account = await this.requireAccount(ownerId, id);
    throw new ConflictException(
      `Trading account was changed by another request (expected version ${version}, current version ${account.version}).`,
    );
  }
}
