import { MultiStageGenerationInput } from "@/backend/custom-model-post-process";
import { getUpdaterFunction, SetEditorStateFunction } from "@/contexts/editor-context-utils";
import { PublicUserId, PublicUserRoles, StateUpdater } from "@/core/common/types";
import { debugError, debugLog } from "@/core/utils/print-utilts";
import { Timestamp } from "firebase/firestore";
import { clamp, noop } from "lodash";
import {
  CustomModelPredictionInputInstant,
  isCustomModelPredictionInputInstant,
} from "./custom-model-predition-instant";
import { CustomModelScaleConfigs, isCustomModelScaleConfigs } from "./custom-model-scale-configs";
import { PublicTeamId } from "./team";
export type { CustomModelScaleConfig, CustomModelScaleConfigs } from "./custom-model-scale-configs";

/** Identifiers for the tabs of the custom model editor. */
export enum CustomModelEditorTab {
  Train = "train",
  Generate = "generate",
}

/** Allowed status values for a dataset item. */
export type CustomModelDatasetItemStatus = "queued" | "processing" | "finished";

/** Dictionary of dataset items (keys are presumably item ids — TODO confirm against callers). */
export type CustomModelDataset = Record<string, CustomModelDatasetItem>;

/** Visibility setting of a document; used by `CustomModelInfo.visibility`. */
export enum DocVisibility {
  Public = "public",
  Private = "private",
}

/**
 * Category of a custom model.
 *
 * `customModelTypeToFrontEndDisplayTemplateTypeMap` is typed as a `Record` over
 * this enum, so every member added here needs an entry in that map.
 */
export enum CustomModelType {
  Product = "Product",
  Fashion = "Fashion",
  Style = "Style",
  Face = "Face",
  Custom = "Custom",
  Furniture = "Furniture",
  Tech = "Tech",
  Food = "Food",
  Vase = "Vase",
  VirtualModel = "VirtualModel",
  Footwear = "Footwear",
  Jewelry = "Jewelry",
  Bags = "Bags",
  BrandA = "BrandA",
  BrandB = "BrandB",
}

/**
 * Model types that `isCustomModelTypeProduct` reports as "product" models.
 *
 * NOTE(review): `Footwear`, `BrandA` and `BrandB` are not listed here while
 * `Jewelry` and `Bags` are — confirm the omission is intentional.
 */
const productCustomModelType = new Set([
  CustomModelType.Product,
  CustomModelType.Fashion,
  CustomModelType.Furniture,
  CustomModelType.Tech,
  CustomModelType.Food,
  CustomModelType.Vase,
  CustomModelType.Jewelry,
  CustomModelType.Bags,
]);

/** Model types that `isCustomModelTypeHuman` reports as "human" models. */
const humanCustomModelType = new Set([CustomModelType.VirtualModel]);

/**
 * Reports whether `type` is one of the product model types.
 *
 * Always returns a strict boolean: `null` and `undefined` yield `false`
 * (previously the falsy input itself leaked out of the `&&` expression).
 *
 * @param type Model type to test; may be missing.
 * @returns `true` when `type` is in the product set, otherwise `false`.
 */
export function isCustomModelTypeProduct(type: CustomModelType | undefined | null): boolean {
  return type != null && productCustomModelType.has(type);
}

/**
 * Reports whether `type` is one of the human model types.
 *
 * Always returns a strict boolean: `null` and `undefined` yield `false`
 * (previously the falsy input itself leaked out of the `&&` expression).
 *
 * @param type Model type to test; may be missing.
 * @returns `true` when `type` is in the human set, otherwise `false`.
 */
export function isCustomModelTypeHuman(type: CustomModelType | undefined | null): boolean {
  return type != null && humanCustomModelType.has(type);
}

/**
 * Template types shown in the frontend.
 *
 * The first fifteen members share their string values with `CustomModelType`;
 * the remaining ones (from `HomeDecor` on) have no `CustomModelType` counterpart
 * and are mapped onto one by `frontendDisplayTemplateTypeToCustomModelTypeMap`.
 */
export enum FrontendDisplayTemplateType {
  Product = "Product",
  Fashion = "Fashion",
  Style = "Style",
  Face = "Face",
  Custom = "Custom",
  Furniture = "Furniture",
  Tech = "Tech",
  Food = "Food",
  Vase = "Vase",
  VirtualModel = "VirtualModel",
  Footwear = "Footwear",
  Jewelry = "Jewelry",
  Bags = "Bags",
  BrandA = "BrandA",
  BrandB = "BrandB",
  HomeDecor = "HomeDecor",
  Toys = "Toys",
  Hats = "Hats",
  OfficeSupplies = "OfficeSupplies",
  Kitchenware = "Kitchenware",
  Vehicles = "Vehicles",
  MensFashion = "MensFashion",
  Dresses = "Dresses",
  Outerwear = "Outerwear",
  Glasses = "Glasses",
  Watches = "Watches",
  Sweaters = "Sweaters",
}

/**
 * Maps every frontend display template type to a `CustomModelType`.
 *
 * Template types that share a name with a model type map to that model type;
 * the display-only templates are collapsed onto `Product` or `Fashion`.
 */
const frontendDisplayTemplateTypeToCustomModelTypeMap: Record<
  FrontendDisplayTemplateType,
  CustomModelType
> = {
  [FrontendDisplayTemplateType.Product]: CustomModelType.Product,
  [FrontendDisplayTemplateType.Fashion]: CustomModelType.Fashion,
  [FrontendDisplayTemplateType.Style]: CustomModelType.Style,
  [FrontendDisplayTemplateType.Face]: CustomModelType.Face,
  [FrontendDisplayTemplateType.Custom]: CustomModelType.Custom,
  [FrontendDisplayTemplateType.Furniture]: CustomModelType.Furniture,
  [FrontendDisplayTemplateType.Tech]: CustomModelType.Tech,
  [FrontendDisplayTemplateType.Food]: CustomModelType.Food,
  [FrontendDisplayTemplateType.Vase]: CustomModelType.Vase,
  [FrontendDisplayTemplateType.VirtualModel]: CustomModelType.VirtualModel,
  [FrontendDisplayTemplateType.Footwear]: CustomModelType.Footwear,
  [FrontendDisplayTemplateType.Jewelry]: CustomModelType.Jewelry,
  [FrontendDisplayTemplateType.Bags]: CustomModelType.Bags,
  [FrontendDisplayTemplateType.BrandA]: CustomModelType.BrandA,
  [FrontendDisplayTemplateType.BrandB]: CustomModelType.BrandB,
  // Display-only templates below have no dedicated model type.
  [FrontendDisplayTemplateType.HomeDecor]: CustomModelType.Product,
  [FrontendDisplayTemplateType.Toys]: CustomModelType.Product,
  [FrontendDisplayTemplateType.Hats]: CustomModelType.Fashion,
  [FrontendDisplayTemplateType.OfficeSupplies]: CustomModelType.Product,
  [FrontendDisplayTemplateType.Kitchenware]: CustomModelType.Product,
  [FrontendDisplayTemplateType.Vehicles]: CustomModelType.Product,
  [FrontendDisplayTemplateType.MensFashion]: CustomModelType.Fashion,
  [FrontendDisplayTemplateType.Dresses]: CustomModelType.Fashion,
  [FrontendDisplayTemplateType.Outerwear]: CustomModelType.Fashion,
  [FrontendDisplayTemplateType.Glasses]: CustomModelType.Fashion,
  [FrontendDisplayTemplateType.Watches]: CustomModelType.Fashion,
  [FrontendDisplayTemplateType.Sweaters]: CustomModelType.Fashion,
};

/**
 * Maps each `CustomModelType` to the frontend display template type of the same name.
 * The `Record` type makes the compiler require an entry for every model type.
 */
export const customModelTypeToFrontEndDisplayTemplateTypeMap: Record<
  CustomModelType,
  FrontendDisplayTemplateType
> = {
  [CustomModelType.Product]: FrontendDisplayTemplateType.Product,
  [CustomModelType.Fashion]: FrontendDisplayTemplateType.Fashion,
  [CustomModelType.Style]: FrontendDisplayTemplateType.Style,
  [CustomModelType.Face]: FrontendDisplayTemplateType.Face,
  [CustomModelType.Custom]: FrontendDisplayTemplateType.Custom,
  [CustomModelType.Furniture]: FrontendDisplayTemplateType.Furniture,
  [CustomModelType.Tech]: FrontendDisplayTemplateType.Tech,
  [CustomModelType.Food]: FrontendDisplayTemplateType.Food,
  [CustomModelType.Vase]: FrontendDisplayTemplateType.Vase,
  [CustomModelType.VirtualModel]: FrontendDisplayTemplateType.VirtualModel,
  [CustomModelType.Footwear]: FrontendDisplayTemplateType.Footwear,
  [CustomModelType.Jewelry]: FrontendDisplayTemplateType.Jewelry,
  [CustomModelType.Bags]: FrontendDisplayTemplateType.Bags,
  [CustomModelType.BrandA]: FrontendDisplayTemplateType.BrandA,
  [CustomModelType.BrandB]: FrontendDisplayTemplateType.BrandB,
};

/**
 * Type guard for `FrontendDisplayTemplateType`.
 *
 * @param value Arbitrary value to test.
 * @returns `true` when `value` equals one of the enum's string values.
 */
export function isFrontendDisplayTemplateType(value: any): value is FrontendDisplayTemplateType {
  const knownTemplateTypes = Object.values(FrontendDisplayTemplateType);
  return knownTemplateTypes.includes(value);
}

/**
 * Type guard for `CustomModelType`.
 *
 * @param value Arbitrary value to test.
 * @returns `true` when `value` equals one of the enum's string values.
 */
export function isCustomModelType(value: any): value is CustomModelType {
  const knownModelTypes = Object.values(CustomModelType);
  return knownModelTypes.includes(value);
}

/**
 * Resolves the `CustomModelType` for a display template type.
 *
 * Values that already are model types (the two enums share string values for
 * their common members) are returned unchanged; everything else goes through
 * `frontendDisplayTemplateTypeToCustomModelTypeMap`.
 *
 * @param displayTemplateType Display template type or model type.
 * @returns The corresponding model type.
 */
export function getCustomModelTypeFromFrontendDisplayTemplateType(
  displayTemplateType: FrontendDisplayTemplateType | CustomModelType,
): CustomModelType {
  return isCustomModelType(displayTemplateType)
    ? displayTemplateType
    : frontendDisplayTemplateTypeToCustomModelTypeMap[displayTemplateType];
}

/**
 * Looks up the frontend display template type for a model type.
 *
 * @param customModelType Model type to translate.
 * @returns The entry of `customModelTypeToFrontEndDisplayTemplateTypeMap` for that type.
 */
export function getFrontendDisplayTemplateTypeFromCustomModelType(
  customModelType: CustomModelType,
) {
  const displayTemplateType = customModelTypeToFrontEndDisplayTemplateTypeMap[customModelType];
  return displayTemplateType;
}

/**
 * Picks the display template type ("workflow") for a model.
 *
 * Precedence: the explicit `frontendDisplayTemplateType`, then the template
 * derived from `customModelType`, then `FrontendDisplayTemplateType.Custom`.
 *
 * @param customModelInfo Object carrying the optional template and model types.
 * @returns The resolved display template type.
 */
export function getCustomModelWorkflowFromCustomModelInfo(customModelInfo: {
  frontendDisplayTemplateType?: FrontendDisplayTemplateType;
  customModelType?: CustomModelType;
}): FrontendDisplayTemplateType {
  const explicitTemplateType = customModelInfo?.frontendDisplayTemplateType;
  if (explicitTemplateType) {
    return explicitTemplateType;
  }

  const modelType = customModelInfo?.customModelType;
  const derivedTemplateType =
    modelType && getFrontendDisplayTemplateTypeFromCustomModelType(modelType);

  return derivedTemplateType || FrontendDisplayTemplateType.Custom;
}

/** Values for the `content_or_style` option of generic-hardware training inputs. */
export enum CustomModelTrainingContentOrStyle {
  Content = "content",
  Style = "style",
  Balanced = "balanced",
}

/**
 * Metadata record describing a custom model.
 *
 * `isCustomModelInfo` only validates `id`, `displayName` and `roles`; the other
 * required fields are not checked at runtime.
 */
export interface CustomModelInfo {
  id: string;
  customModelType?: CustomModelType;
  frontendDisplayTemplateType?: FrontendDisplayTemplateType;
  displayName?: string;
  timeCreated: Timestamp;
  timeModified: Timestamp;
  isDeleted: boolean;
  publicTeamId: PublicTeamId;
  roles: PublicUserRoles;
  thumbnailStoragePath?: string;
  /** Subset of a prediction item (input, output, used models) kept as the model's default. */
  defaultPredictionItem?: Pick<CustomModelPredictionItem, "input" | "output" | "usedModels">;
  visibility: DocVisibility;
}

/**
 * Shallow type guard for `CustomModelInfo`.
 *
 * Accepts any non-nullish value with a string `id`, a `displayName` that is
 * either absent or a string, and a non-null object `roles`. Other fields of the
 * interface are not inspected.
 *
 * @param obj Arbitrary value to test.
 * @returns `true` when the checked fields have the expected shapes.
 */
export function isCustomModelInfo(obj: any): obj is CustomModelInfo {
  if (obj == null) {
    return false;
  }
  if (typeof obj.id !== "string") {
    return false;
  }

  const displayNameKind = typeof obj.displayName;
  if (displayNameKind !== "undefined" && displayNameKind !== "string") {
    return false;
  }

  return typeof obj.roles === "object" && obj.roles != null;
}

/** Tag values assignable to a dataset item via `CustomModelDatasetItem.multiDatasetTag`. */
export enum MultiDatasetTag {
  Product = "Product",
  Human = "Human",
  Other = "Other",
}

/**
 * One entry of a training dataset.
 *
 * `isCustomModelDatasetItem` validates only `id` and `storagePath`.
 */
export interface CustomModelDatasetItem {
  id: string;
  caption?: string;
  storagePath: string;
  timeCreated: Timestamp;
  timeModified: Timestamp;
  multiDatasetTag?: MultiDatasetTag;
}

/**
 * Shallow type guard for `CustomModelDatasetItem`.
 *
 * @param obj Arbitrary value to test.
 * @returns `true` when `obj` is a non-null object with string `id` and `storagePath`.
 */
export function isCustomModelDatasetItem(obj: any): obj is CustomModelDatasetItem {
  if (typeof obj !== "object" || obj === null) {
    return false;
  }
  if (typeof obj.id !== "string") {
    return false;
  }
  return typeof obj.storagePath === "string";
}

/**
 * Status values of a training job. `CustomModelPredictionItem.status` reuses
 * this enum as well.
 */
export enum CustomModelTrainingStatus {
  Starting = "starting",
  Processing = "processing",
  // Stored as the plain string "paused".
  GenericHardwareCheckpointPaused = "paused",
  Succeeded = "succeeded",
  Failed = "failed",
  Canceled = "canceled",
}

/** Statuses that `isCustomModelTrainingStatusActive` treats as active (paused counts as active). */
const activeCustomModelTrainingStatuses = new Set([
  CustomModelTrainingStatus.Processing,
  CustomModelTrainingStatus.Starting,
  CustomModelTrainingStatus.GenericHardwareCheckpointPaused,
]);

/**
 * Reports whether a training status is one of the active statuses
 * (starting, processing, or checkpoint-paused).
 *
 * Always returns a strict boolean: a missing status yields `false`
 * (previously `undefined` leaked out of the `&&` expression).
 *
 * @param status Status to test; may be omitted.
 * @returns `true` when the status is active, otherwise `false`.
 */
export function isCustomModelTrainingStatusActive(status?: CustomModelTrainingStatus): boolean {
  return status != null && activeCustomModelTrainingStatuses.has(status);
}

/**
 * Training backends. The short string values are what is stored on
 * `CustomModelTrainingInput.backendType`.
 */
export enum CustomModelTrainingBackendType {
  GenericHardware = "R", // This used to be Replicate but now we support many types of hardware
  Fal = "F",
  Instant = "Instant",
}

/**
 * Reports whether the backend is the high-quality one, i.e. `GenericHardware`.
 *
 * @param backendType Training backend to test.
 * @returns `true` only for `CustomModelTrainingBackendType.GenericHardware`.
 */
export function isHighQualityTrainingBackendType(backendType: CustomModelTrainingBackendType) {
  const highQualityBackend = CustomModelTrainingBackendType.GenericHardware;
  return backendType === highQualityBackend;
}

/**
 * Reports whether the backend is the fast one, i.e. `Fal`.
 *
 * @param backendType Training backend to test.
 * @returns `true` only for `CustomModelTrainingBackendType.Fal`.
 */
export function isFastTrainingBackendType(backendType: CustomModelTrainingBackendType) {
  const fastBackend = CustomModelTrainingBackendType.Fal;
  return backendType === fastBackend;
}

/** User-facing label for each training backend. */
export const customModelTrainingBackendToDisplayName: Record<
  CustomModelTrainingBackendType,
  string
> = {
  [CustomModelTrainingBackendType.GenericHardware]: "Standard",
  [CustomModelTrainingBackendType.Fal]: "Fast",
  [CustomModelTrainingBackendType.Instant]: "Instant",
};

/**
 * Returns the user-facing backend label for a training item.
 *
 * Falls back to the `Fal` label when the item has no backend type, when the
 * stored backend value has no entry in the label table, or when reading the
 * item throws (for example a nullish `training`).
 *
 * @param training Training item whose backend should be labelled.
 * @returns The display name; never `undefined`.
 */
export function getCustomModelTrainingBackendDisplayName(training: CustomModelTrainingItem) {
  const fallbackDisplayName =
    customModelTrainingBackendToDisplayName[CustomModelTrainingBackendType.Fal];

  try {
    const backendType = training.input?.backendType || CustomModelTrainingBackendType.Fal;
    // A plain property lookup does not throw for unknown keys, so a backend value
    // missing from the table would otherwise come back as undefined.
    return customModelTrainingBackendToDisplayName[backendType] ?? fallbackDisplayName;
  } catch (error) {
    debugError("[getCustomModelTrainingBackendDisplayName] Error: ", error);
    return fallbackDisplayName;
  }
}

/**
 * Training input for the `Fal` backend.
 *
 * `isCustomModelTrainingInputFalSpecific` validates only `backendType` and
 * `trigger_word`.
 */
export interface CustomModelTrainingInputFalSpecific {
  backendType: CustomModelTrainingBackendType.Fal;
  trigger_word: string;
  iter_multiplier?: number;
  is_style?: boolean;
  is_input_format_already_preprocessed?: boolean;
  modelId?: string;
  trainingStrengthPercent: number;
}

/**
 * Type guard for `Fal` training inputs.
 *
 * @param obj Arbitrary value to test.
 * @returns `true` when `obj` has a string `trigger_word` and the `Fal` backend type.
 */
export function isCustomModelTrainingInputFalSpecific(
  obj: any,
): obj is CustomModelTrainingInputFalSpecific {
  if (obj == null) {
    return false;
  }
  if (typeof obj.trigger_word !== "string") {
    return false;
  }
  return obj.backendType === CustomModelTrainingBackendType.Fal;
}

/**
 * Training input for the `GenericHardware` backend.
 * There used to be other input types, but they were all merged into this one.
 */
export interface CustomModelTrainingInputGenericHardwareSpecific {
  backendType: CustomModelTrainingBackendType.GenericHardware;
  steps: number;
  batch_size: number;
  autocaption?: boolean;
  trigger_word: string;
  learning_rate: number;
  // Does nothing on Replicate, but apparently the ostris trainer still has this option.
  // "content" is for products and presumably "style" is for general vibes.
  content_or_style?: CustomModelTrainingContentOrStyle;
  trainingId?: string;
  trainingStrengthPercent: number;
  modelId?: string; // Not used; only present so the union types line up.
}

/**
 * Type guard for `GenericHardware` training inputs. Only `backendType` is checked.
 *
 * @param obj Arbitrary value to test.
 * @returns `true` when `obj` is non-nullish and carries the `GenericHardware` backend type.
 */
export function isCustomModelTrainingInputGenericHardwareSpecific(
  obj: any,
): obj is CustomModelTrainingInputGenericHardwareSpecific {
  if (obj == null) {
    return false;
  }
  return obj.backendType === CustomModelTrainingBackendType.GenericHardware;
}

/** Training input for the `Instant` backend. */
export interface CustomModelTrainingInputInstant {
  backendType: CustomModelTrainingBackendType.Instant;
  trigger_word: string;
  modelId?: string;
}

/**
 * Type guard for `Instant` training inputs.
 *
 * @param value Arbitrary value to test.
 * @returns `true` when `value` is a non-null object with the `Instant` backend
 *   type and a string `trigger_word`.
 */
export function isCustomModelTrainingInputInstant(
  value: any,
): value is CustomModelTrainingInputInstant {
  if (value === null || typeof value !== "object") {
    return false;
  }
  if (!("backendType" in value)) {
    return false;
  }
  if (value.backendType !== CustomModelTrainingBackendType.Instant) {
    return false;
  }
  return "trigger_word" in value && typeof value.trigger_word === "string";
}

/** Union of all training inputs, discriminated by `backendType`. */
export type CustomModelTrainingInput =
  | CustomModelTrainingInputFalSpecific
  | CustomModelTrainingInputGenericHardwareSpecific
  | CustomModelTrainingInputInstant;

/**
 * Type guard for any training input. The backend-specific guards are tried in
 * the order Fal, GenericHardware, Instant.
 *
 * @param obj Arbitrary value to test.
 * @returns `true` when one of the backend-specific guards accepts `obj`.
 */
export function isCustomModelTrainingInput(obj: any): obj is CustomModelTrainingInput {
  if (isCustomModelTrainingInputFalSpecific(obj)) {
    return true;
  }
  if (isCustomModelTrainingInputGenericHardwareSpecific(obj)) {
    return true;
  }
  return isCustomModelTrainingInputInstant(obj);
}

/**
 * Builds a fresh default training input, targeting the `GenericHardware` backend.
 *
 * NOTE(review): `trainingStrengthPercent` defaults to 0.5 although the name says
 * "percent" — confirm whether the unit is a 0–1 fraction.
 *
 * @returns A new object on every call, so callers may mutate it freely.
 */
export function getDefaultCustomModelTrainingInput(): CustomModelTrainingInput {
  const defaultTrainingInput: CustomModelTrainingInputGenericHardwareSpecific = {
    backendType: CustomModelTrainingBackendType.GenericHardware,
    steps: 1000,
    batch_size: 1,
    autocaption: true,
    trigger_word: "TOK",
    learning_rate: 0.0004,
    trainingStrengthPercent: 0.5,
  };
  return defaultTrainingInput;
}

/**
 * Validates and normalizes model scale configurations.
 *
 * Validation (per entry):
 * - entries that are null/undefined, or whose `modelId` / `trainingId` is not a
 *   non-blank string, are dropped;
 * - a `scale` that is not a finite number (wrong type, `NaN`, `±Infinity`)
 *   defaults to 1;
 * - a `scale` of exactly 0 drops the entry;
 * - any remaining scale below 0.001 (including negatives) is raised to 0.001.
 *
 * Normalization:
 * - zero or one valid model: scales are returned as validated;
 * - multiple models: scales are multiplied by a common factor so they sum to
 *   `normalizationTarget`, but only if their total exceeds that target.
 *
 * @param scaleConfigs Scale configs keyed by an arbitrary string; entries may be partial.
 * @param normalizationTarget Maximum allowed sum of scales for multiple models (default: 2.0).
 * @returns A new configs object; `{}` if an unexpected error occurs.
 */
export function correctCustomModelScaleConfigs(
  scaleConfigs: Partial<CustomModelScaleConfigs>,
  normalizationTarget: number = 2.0,
): CustomModelScaleConfigs {
  try {
    debugLog("Correcting custom model scale configs:", { scaleConfigs, normalizationTarget });

    // First pass: validate configs and filter out invalid ones
    const validConfigs = Object.entries(scaleConfigs).reduce((acc, [key, partialConfig]) => {
      try {
        if (!partialConfig) {
          debugLog(`Skipping null/undefined config for key: ${key}`);
          return acc;
        }

        const { modelId, trainingId } = partialConfig;
        let { scale } = partialConfig;

        debugLog(`Processing config for key ${key}:`, { modelId, trainingId, scale });

        // Validate modelId and trainingId
        if (typeof modelId !== "string" || !modelId.trim()) {
          debugLog(`Invalid modelId for key ${key}: ${modelId}`);
          return acc;
        }
        if (typeof trainingId !== "string" || !trainingId.trim()) {
          debugLog(`Invalid trainingId for key ${key}: ${trainingId}`);
          return acc;
        }

        // Validate scale (default to 1 if invalid). Infinite values are rejected
        // too: an infinite total would make the normalization factor 0 and turn
        // the scale into NaN (Infinity * 0).
        if (typeof scale !== "number" || !Number.isFinite(scale)) {
          debugLog(`Invalid scale for key ${key}, defaulting to 1: ${scale}`);
          scale = 1;
        }

        // Discard if scale is 0
        if (scale === 0) {
          debugLog(`Discarding config with zero scale for key ${key}`);
          return acc;
        }

        // Ensure scale is positive
        const originalScale = scale;
        scale = Math.max(0.001, scale);
        if (scale !== originalScale) {
          debugLog(
            `Adjusted scale to ensure positive value for key ${key}: ${originalScale} -> ${scale}`,
          );
        }

        acc[key] = { scale, modelId, trainingId };
        return acc;
      } catch (error) {
        // A single malformed entry must not discard the whole result.
        debugError(`Error processing config for key ${key}:`, error);
        return acc;
      }
    }, {} as CustomModelScaleConfigs);

    // Second pass: normalize scales if needed
    const totalScale = Object.values(validConfigs).reduce((sum, config) => sum + config.scale, 0);

    // If only one model or total scale is at or below target, return as is
    if (Object.keys(validConfigs).length <= 1 || totalScale <= normalizationTarget) {
      debugLog("No normalization needed, returning valid configs:", validConfigs);
      return validConfigs;
    }

    // Normalize scales to sum to normalizationTarget
    const normalizationFactor = normalizationTarget / totalScale;
    const normalizedConfigs = Object.entries(validConfigs).reduce((acc, [key, config]) => {
      acc[key] = { ...config, scale: config.scale * normalizationFactor };
      return acc;
    }, {} as CustomModelScaleConfigs);

    debugLog("Normalized configs:", normalizedConfigs);
    return normalizedConfigs;
  } catch (error) {
    debugError("Error in correctCustomModelScaleConfigs:", error);
    return {};
  }
}

/**
 * Applies `correctCustomModelScaleConfigs` to every scale-config property of a
 * prediction input.
 *
 * The properties `scaleConfigs`, `scale_configs` and `modelTrainingPairs` are
 * inspected; each one that passes `isCustomModelScaleConfigs` is replaced by
 * its corrected version on a shallow copy of the input. The input itself is
 * never mutated, and it is returned unchanged when nothing needed correcting,
 * when it is not an object, or when an error occurs.
 *
 * @param input Prediction input (possibly partial).
 * @param normalizationTarget Forwarded to `correctCustomModelScaleConfigs` (default: 2.0).
 * @returns The original input or a shallow copy with corrected scale configs.
 */
export function correctCustomModelPredictionInputScaleConfigs<
  T extends Partial<CustomModelPredictionInput>,
>(input: T, normalizationTarget: number = 2.0): T {
  try {
    if (input === null || typeof input !== "object") {
      return input;
    }

    const source = input;
    const scaleConfigKeys = ["scaleConfigs", "scale_configs", "modelTrainingPairs"];
    let corrected: T | null = null;

    for (const propName of scaleConfigKeys) {
      const candidate = source[propName];
      if (!isCustomModelScaleConfigs(candidate)) {
        continue;
      }
      // Copy lazily so untouched inputs keep their identity.
      corrected = corrected || { ...source };
      corrected[propName] = correctCustomModelScaleConfigs(candidate, normalizationTarget);
    }

    if (corrected) {
      return corrected as T;
    }
    return input;
  } catch (error) {
    debugError("Error correcting prediction input scale configs:", error);
    return input;
  }
}

/** Kinds of prompt update tracked by `CustomModelPlaygroundPromptEditorState.isUpdatingPromptType`. */
export enum CustomModelPlaygroundPromptEditorUpdatePromptType {
  HumanModel = "HumanModel",
  HumanAngle = "HumanAngle",
  BackgroundTemplate = "BackgroundTemplate",
  Product = "Product",
}

/**
 * State of the playground prompt editor.
 *
 * NOTE(review): `json` and `text` are presumably the serialized editor state and
 * its plain-text form — confirm against the editor component.
 */
export interface CustomModelPlaygroundPromptEditorState {
  json: string;
  text: string;
  scaleConfigs: CustomModelScaleConfigs;
  isUpdatingPrompt: boolean;
  // Despite the `is` prefix this holds the update type, not a boolean.
  isUpdatingPromptType: CustomModelPlaygroundPromptEditorUpdatePromptType;
}

/**
 * Discriminator values for prediction inputs.
 * Some string values differ from the member names, so compare against the enum
 * members rather than hard-coded strings.
 */
export enum CustomModelPredictionInputBackendType {
  SelfHost = "S",
  Fal = "F",
  FixDetail = "FixDetail",
  RegenerateHuman = "RegenerateHuman",
  UpscaleCreative = "UpscaleCreative",
  // Stored value keeps the older "UpscaleCreativeV2" spelling.
  ClarityUpscale = "UpscaleCreativeV2",
  GenerateVariations = "GenerateVariations",
  InContextVariations = "ICV",
  MultiStageGeneration = "MultiStageGeneration",
}

/** Image dimensions (presumably pixels — TODO confirm). */
export interface ImageSizeFalType {
  width: number;
  height: number;
}

/** Output image formats accepted by `CustomModelPredictionInputFal.output_format`. */
export enum OutputFormatFalType {
  JPEG = "jpeg",
  PNG = "png",
}

/**
 * Prediction input for the `Fal` backend.
 *
 * `isCustomModelPredictionInputFal` validates only `backendType`, `prompt` and
 * the presence of `scaleConfigs`.
 */
export interface CustomModelPredictionInputFal {
  backendType: CustomModelPredictionInputBackendType.Fal;
  scaleConfigs: CustomModelScaleConfigs;
  prompt: string;
  promptJson?: string;
  promptEditorState?: string;
  seed?: number;
  sync_mode?: boolean;
  image_size?: ImageSizeFalType;
  num_images?: number;
  output_format?: OutputFormatFalType;
  guidance_scale?: number;
  num_inference_steps?: number;
  enable_safety_checker?: boolean;
  prompt_suffix?: string;
}

/**
 * Type guard for `Fal` prediction inputs.
 *
 * Always returns a strict boolean; previously a falsy `input` (`null`,
 * `undefined`, `0`, `""`) was returned as-is from the `&&` chain.
 *
 * @param input Arbitrary value to test.
 * @returns `true` when `input` has the `Fal` backend type, a string `prompt`
 *   and a non-nullish `scaleConfigs`.
 */
export function isCustomModelPredictionInputFal(
  input: any,
): input is CustomModelPredictionInputFal {
  return (
    input != null &&
    input.backendType === CustomModelPredictionInputBackendType.Fal &&
    typeof input.prompt === "string" &&
    input.scaleConfigs != null
  );
}

/**
 * Multi-stage generation input tagged with its backend discriminator.
 * NOTE(review): this type is not a member of the `CustomModelPredictionInput`
 * union even though the self-host conversion handles it — confirm whether it
 * should be added.
 */
export interface CustomModelPredictionInputMultiStageGeneration extends MultiStageGenerationInput {
  backendType: CustomModelPredictionInputBackendType.MultiStageGeneration;
}

/**
 * Type guard for multi-stage generation inputs. Only `backendType` is checked.
 *
 * Always returns a strict boolean; previously a falsy `input` was returned
 * as-is from the `&&` expression.
 *
 * @param input Arbitrary value to test.
 * @returns `true` when `input` carries the `MultiStageGeneration` backend type.
 */
export function isCustomModelPredictionInputMultiStageGeneration(
  input: any,
): input is CustomModelPredictionInputMultiStageGeneration {
  return (
    input != null &&
    input.backendType === CustomModelPredictionInputBackendType.MultiStageGeneration
  );
}

/**
 * Prediction input for the self-hosted backend. Other input kinds are converted
 * to this shape by
 * `getCustomModelPredictionInputSelfHostFromCustomModelPredictionInput`.
 */
export interface CustomModelPredictionInputSelfHost {
  backendType: CustomModelPredictionInputBackendType.SelfHost;
  width: number;
  height: number;
  prompt: string;
  promptJson?: string;
  promptEditorState?: string;
  lora_scale?: number;
  num_outputs: number;
  aspect_ratio?: string;
  output_format?: string;
  guidance_scale: number;
  output_quality?: number;
  num_inference_steps: number;
  scale_configs?: CustomModelScaleConfigs;
  humanAngle?: string;
  backgroundTemplate?: string;
}

/**
 * Structural type guard for self-host prediction inputs.
 *
 * Checks for a string `prompt` and numeric `width`, `height` and
 * `num_inference_steps`. `backendType` is deliberately left unchecked here to
 * preserve existing behavior (NOTE(review): confirm whether it should be).
 *
 * Always returns a strict boolean; previously a falsy `input` was returned
 * as-is from the `&&` chain.
 *
 * @param input Arbitrary value to test.
 * @returns `true` when the checked fields have the expected types.
 */
export function isCustomModelPredictionInputSelfHost(
  input: any,
): input is CustomModelPredictionInputSelfHost {
  return (
    input != null &&
    typeof input.prompt === "string" &&
    typeof input.width === "number" &&
    typeof input.height === "number" &&
    typeof input.num_inference_steps === "number"
  );
}

/** Optional overrides carried in `CustomModelPredictionInputFixDetails.extraPredictionArgs`. */
export interface CustomModelFixDetailsFalArgsOverride {
  num_inference_steps?: number;
  guidance_scale?: number;
  num_images?: number;
  enable_safety_checker?: boolean;
  output_format?: string;
  strength?: number;
  seed?: number;
}

/**
 * Prediction input for caption-based operations.
 *
 * Unlike the sibling input types, `backendType` is typed as the whole enum
 * rather than a single member, so this type does not discriminate on it.
 */
export interface CustomModelPredictionInputFixDetails {
  backendType: CustomModelPredictionInputBackendType;
  shortCaption: string;
  fullCaption: string;
  promptJson?: string;
  modelTrainingPairs: CustomModelScaleConfigs;
  extraPredictionArgs?: CustomModelFixDetailsFalArgsOverride;
  image_size?: ImageSizeFalType;
}

/**
 * Type guard for `CustomModelPredictionInputFixDetails`.
 *
 * Requires string `backendType`, `shortCaption` and `fullCaption`, plus a
 * non-null object `modelTrainingPairs`. The explicit null check matters because
 * `typeof null === "object"`; without it a null `modelTrainingPairs` would pass
 * even though the interface declares the field as required.
 *
 * @param input Arbitrary value to test.
 * @returns `true` when all checked fields have the expected types.
 */
export function isCustomModelPredictionInputFixDetails(
  input: any,
): input is CustomModelPredictionInputFixDetails {
  return (
    typeof input === "object" &&
    input !== null &&
    typeof input.backendType === "string" &&
    typeof input.shortCaption === "string" &&
    typeof input.fullCaption === "string" &&
    typeof input.modelTrainingPairs === "object" &&
    input.modelTrainingPairs !== null
  );
}

/**
 * Prediction input describing a multi-node workflow as a JSON graph payload.
 *
 * It shares the `SelfHost` backend discriminator with
 * `CustomModelPredictionInputSelfHost`; the two are told apart by the presence
 * of `graphPayloadJson` (see `isCustomModelPredictionInputCreateMultiNodeWorkflow`).
 */
export interface CustomModelPredictionInputCreateMultiNodeWorkflow {
  backendType: CustomModelPredictionInputBackendType.SelfHost;
  graphPayloadJson: string;
  promptJson: string;
  modelTrainingPairs: CustomModelScaleConfigs;
  image_size?: ImageSizeFalType;
}

/**
 * Type guard for multi-node workflow inputs; the only check is a string
 * `graphPayloadJson`.
 *
 * Always returns a strict boolean; previously a falsy `input` was returned
 * as-is from the `&&` expression.
 *
 * @param input Arbitrary value to test.
 * @returns `true` when `input.graphPayloadJson` is a string.
 */
export function isCustomModelPredictionInputCreateMultiNodeWorkflow(
  input: any,
): input is CustomModelPredictionInputCreateMultiNodeWorkflow {
  return input != null && typeof input.graphPayloadJson === "string";
}

/**
 * Union of the prediction input shapes stored on a prediction item.
 * The members do not all discriminate on `backendType`, so use the `is…`
 * type guards in this file to tell them apart.
 */
export type CustomModelPredictionInput =
  | CustomModelPredictionInputFal
  | CustomModelPredictionInputSelfHost
  | CustomModelPredictionInputFixDetails
  | CustomModelPredictionInputCreateMultiNodeWorkflow
  | CustomModelPredictionInputInstant;

/**
 * Extracts the scale configs from a prediction input, whichever property name
 * the input variant uses.
 *
 * The properties are probed in the order `scaleConfigs`, `scale_configs`,
 * `modelTrainingPairs`; the first one that passes `isCustomModelScaleConfigs`
 * wins.
 *
 * @param input Prediction input to read from.
 * @returns The first valid scale configs, or `undefined` when none is found,
 *   the input is falsy, or an error occurs.
 */
export function getCustomModelScaleConfigsFromCustomModelPredictionInput(
  input: CustomModelPredictionInput,
): CustomModelScaleConfigs | undefined {
  try {
    if (!input) {
      return undefined;
    }

    const looseInput = input as any;

    for (const propName of ["scaleConfigs", "scale_configs", "modelTrainingPairs"]) {
      const candidate = looseInput[propName];
      if (candidate !== undefined && isCustomModelScaleConfigs(candidate)) {
        return candidate;
      }
    }

    return undefined;
  } catch (error) {
    debugError("Error extracting scale configs:", error);
    return undefined;
  }
}

/**
 * Clamps an image side length to the range [128, 8192].
 *
 * Falls back to 1024 when the value cannot be interpreted as a number. lodash
 * `clamp` does not throw for such input — it returns `NaN` — so the fallback
 * has to be applied explicitly; the `catch` alone never covered that case.
 *
 * @param length Requested side length.
 * @returns The clamped length, or 1024 for non-numeric input.
 */
function clampImageLength(length: number) {
  try {
    const clampedLength = clamp(length, 128, 8192);
    return Number.isNaN(clampedLength) ? 1024 : clampedLength;
  } catch (error) {
    debugError("Error clamping image length: ", error);
    return 1024;
  }
}

/**
 * Converts any supported prediction input into the self-host input shape.
 *
 * The variants are tried in this order, and the first matching guard wins:
 * 1. self-host input — returned unchanged;
 * 2. Fal input;
 * 3. multi-stage generation input;
 * 4. multi-node workflow input — values are read from the first
 *    `FluxGeneralNode` that has `inputs` in the parsed graph payload; defaults
 *    are used when the payload cannot be parsed or contains no such node;
 * 5. instant input;
 * 6. anything else is treated as a fix-details input.
 *
 * Width and height always go through `clampImageLength`. Missing values fall
 * back to 1024x1024, 1 output and guidance scale 3.5; the default step count
 * differs per variant (see the branches below).
 *
 * @param customModelInput Prediction input of any supported variant.
 * @returns The equivalent self-host prediction input.
 */
export function getCustomModelPredictionInputSelfHostFromCustomModelPredictionInput(
  customModelInput: CustomModelPredictionInput,
): CustomModelPredictionInputSelfHost {
  if (isCustomModelPredictionInputSelfHost(customModelInput)) {
    return customModelInput;
  }

  if (isCustomModelPredictionInputFal(customModelInput)) {
    return {
      backendType: CustomModelPredictionInputBackendType.SelfHost,
      width: clampImageLength(customModelInput.image_size?.width ?? 1024),
      height: clampImageLength(customModelInput.image_size?.height ?? 1024),
      prompt: customModelInput.prompt,
      promptJson: customModelInput.promptJson,
      promptEditorState: customModelInput.promptEditorState,
      num_outputs: customModelInput.num_images ?? 1,
      num_inference_steps: customModelInput.num_inference_steps ?? 28,
      guidance_scale: customModelInput.guidance_scale ?? 3.5,
      scale_configs: customModelInput.scaleConfigs,
    };
  }

  if (isCustomModelPredictionInputMultiStageGeneration(customModelInput)) {
    return {
      backendType: CustomModelPredictionInputBackendType.SelfHost,
      width: clampImageLength(customModelInput.imageSize?.width ?? 1024),
      height: clampImageLength(customModelInput.imageSize?.height ?? 1024),
      prompt: customModelInput.prompt ?? "",
      promptJson: customModelInput.promptJson,
      num_outputs: customModelInput.numImages ?? 1,
      num_inference_steps: 28,
      guidance_scale: 3.5,
      scale_configs: customModelInput.customModelScaleConfigs ?? {},
    };
  }
  if (isCustomModelPredictionInputCreateMultiNodeWorkflow(customModelInput)) {
    try {
      const graphPayload = JSON.parse(customModelInput.graphPayloadJson);

      // Generation parameters live on the first FluxGeneralNode that has inputs.
      const firstNodeWithInputs = Object.values(graphPayload.nodes).find(
        (node: any) => node?.inputs && node.type === "FluxGeneralNode",
      ) as any;

      const {
        width = 1024,
        height = 1024,
        num_inference_steps = 28,
        guidance_scale = 3.5,
        num_images = 1,
        prompt = "",
      } = firstNodeWithInputs?.inputs || {};

      // Extract custom models from the same node. Optional chaining keeps a
      // graph without a FluxGeneralNode on the defaults path instead of throwing.
      const customModels = firstNodeWithInputs?.customModels || {};
      const scale_configs: CustomModelScaleConfigs = {};

      // Convert customModels format to scale_configs format
      Object.entries(customModels).forEach(([modelId, modelInfo]: [string, any]) => {
        scale_configs[modelId] = {
          scale: modelInfo.scale || 1,
          modelId: modelInfo.modelId || modelId,
          trainingId: modelInfo.trainingId || "",
        };
      });

      return {
        backendType: CustomModelPredictionInputBackendType.SelfHost,
        width: clampImageLength(width),
        height: clampImageLength(height),
        prompt,
        promptJson: customModelInput.promptJson,
        num_outputs: num_images,
        num_inference_steps,
        guidance_scale,
        scale_configs,
      };
    } catch (error) {
      debugError("Error parsing graphPayloadJson:", error);
      // Return default values if parsing fails. The prompt JSON is still taken
      // from the input: the graph payload is a different document and must not
      // be passed off as the prompt.
      return {
        backendType: CustomModelPredictionInputBackendType.SelfHost,
        width: 1024,
        height: 1024,
        prompt: "",
        promptJson: customModelInput.promptJson,
        num_outputs: 1,
        num_inference_steps: 28,
        guidance_scale: 3.5,
        scale_configs: {},
      };
    }
  }

  if (isCustomModelPredictionInputInstant(customModelInput)) {
    return {
      backendType: CustomModelPredictionInputBackendType.SelfHost,
      width: clampImageLength(customModelInput.width ?? 1024),
      height: clampImageLength(customModelInput.height ?? 1024),
      prompt: customModelInput.prompt,
      promptJson: customModelInput.promptJson,
      num_outputs: customModelInput.numOutputs ?? 1,
      num_inference_steps: 25,
      guidance_scale: 3.5,
      scale_configs: customModelInput.scaleConfigs,
    };
  }

  // Remaining variant: fix-details input.
  // NOTE(review): the step count defaults to 0 here while other branches use
  // 28 or 25 — confirm that 0 is intended.
  return {
    backendType: CustomModelPredictionInputBackendType.SelfHost,
    width: clampImageLength(customModelInput.image_size?.width ?? 1024),
    height: clampImageLength(customModelInput.image_size?.height ?? 1024),
    prompt: customModelInput.fullCaption,
    promptJson: customModelInput.promptJson,
    num_outputs: customModelInput.extraPredictionArgs?.num_images ?? 1,
    num_inference_steps: customModelInput.extraPredictionArgs?.num_inference_steps ?? 0,
    guidance_scale: customModelInput.extraPredictionArgs?.guidance_scale ?? 3.5,
    scale_configs: customModelInput.modelTrainingPairs,
  };
}

/**
 * Reads the image dimensions from a prediction input.
 *
 * Self-host and instant inputs carry `width` / `height` directly. All other
 * variants are read from `image_size`, with each missing dimension defaulting
 * to 1. If reading throws, 1024x1024 is returned.
 *
 * @param customModelInput Prediction input of any supported variant.
 * @returns The width and height found on the input, or the defaults above.
 */
export function getImageSizeFromCustomModelPredictionInput(
  customModelInput: CustomModelPredictionInput,
): { width: number; height: number } {
  try {
    if (isCustomModelPredictionInputSelfHost(customModelInput)) {
      return { width: customModelInput.width, height: customModelInput.height };
    }

    if (isCustomModelPredictionInputInstant(customModelInput)) {
      return { width: customModelInput.width, height: customModelInput.height };
    }

    const { width = 1, height = 1 } = customModelInput?.image_size ?? {};
    return { width, height };
  } catch (error) {
    debugLog("Error getting image size from prediction input: ", error);

    return { width: 1024, height: 1024 };
  }
}

/**
 * Loads an image through an `Image` element and reports its natural size.
 *
 * The returned promise never rejects: if the image fails to load, or creating
 * the element throws, it resolves with 1024x1024 after logging the error.
 *
 * @param imageUrl URL assigned to the image element's `src`.
 * @returns A promise resolving to the natural width and height, or the default size.
 */
export function getImageSizeFromHtmlImage(
  imageUrl: string,
): Promise<{ width: number; height: number }> {
  return new Promise((resolve) => {
    const resolveWithDefaultSize = () => {
      resolve({ width: 1024, height: 1024 });
    };

    try {
      const imageElement = new Image();

      imageElement.onload = () => {
        resolve({ width: imageElement.naturalWidth, height: imageElement.naturalHeight });
      };

      imageElement.onerror = (loadError) => {
        debugError("Error loading image:", loadError);
        resolveWithDefaultSize();
      };

      // Handlers are attached before `src` is set so no event is missed.
      imageElement.src = imageUrl;
    } catch (error) {
      debugError("Error in getImageSizeFromHtmlImage:", error);
      resolveWithDefaultSize();
    }
  });
}

/**
 * Computes the width/height aspect ratio of a prediction input, clamped to
 * the range [0.1, 10].
 *
 * When the ratio is not a number (for example 0/0 or a missing dimension), a
 * square ratio of 1 is returned; lodash `clamp` would otherwise pass `NaN`
 * straight through to the caller.
 *
 * @param customModelInput Prediction input of any supported variant.
 * @returns The clamped aspect ratio, or 1 when it cannot be computed.
 */
export function getImageAspectRatioFromCustomModelPredictionInput(
  customModelInput: CustomModelPredictionInput,
) {
  const { width, height } = getImageSizeFromCustomModelPredictionInput(customModelInput);

  const aspectRatio = width / height;

  if (Number.isNaN(aspectRatio)) {
    return 1;
  }

  return clamp(aspectRatio, 0.1, 10);
}

/** Action discriminators used as the `type` field of the `Handle…Args` request types. */
export enum CustomModelAction {
  CreateNewModel = "CreateNewModel",
  StartTraining = "StartTraining",
  StopTraining = "StopTraining",
  StartPrediction = "StartPrediction",
  StopPrediction = "StopPrediction",
}

/** Fields shared by every `Handle…Args` request type in this file. */
export interface CustomModelEntrypointAuthArgs {
  publicTeamId: PublicTeamId;
}

/** Request arguments for the `CreateNewModel` action. */
export interface HandleCreateCustomModelArgs extends CustomModelEntrypointAuthArgs {
  type: CustomModelAction.CreateNewModel;
  displayName?: string;
  customModelType: CustomModelType;
  frontendDisplayTemplateType?: FrontendDisplayTemplateType;
}

/** Result of the `CreateNewModel` action, discriminated by `ok`. */
export type HandleCreateCustomModelResponse =
  | { ok: false; message: string }
  | { ok: true; id: string; displayName: string; customModelInfo: CustomModelInfo };

/** Request arguments for the `StopTraining` action. */
export type HandleCustomModelTrainingStopArgs = CustomModelEntrypointAuthArgs & {
  type: CustomModelAction.StopTraining;
  modelId: string;
  trainingId: string;
};

/** Result of the `StopTraining` action. */
export type HandleCustomModelTrainingStopResponse = { ok: boolean; message: string };

/** Request arguments for the `StartTraining` action. */
export interface HandleCustomModelTrainingStartArgs extends CustomModelEntrypointAuthArgs {
  type: CustomModelAction.StartTraining;
  modelId: string;
  trainingInput: CustomModelTrainingInput;
}

/** Result of the `StartTraining` action, discriminated by `ok`. */
export type HandleCustomModelTrainingStartResponse =
  | { ok: false; message: string }
  | { ok: true; newTrainingId: string };

/** Request arguments for the `StartPrediction` action; the input may be partial. */
export type HandleStartCustomModelPredictionArgs = CustomModelEntrypointAuthArgs & {
  type: CustomModelAction.StartPrediction;
  input: Partial<CustomModelPredictionInput>;
};

/** Result of the `StartPrediction` action, discriminated by `ok`. */
export type HandleStartCustomModelPredictionResponse =
  | { ok: false; message: string }
  | { ok: true; message: string; predictionId: string };

/** Request arguments for the `StopPrediction` action. */
export type HandleStopCustomModelPredictionArgs = CustomModelEntrypointAuthArgs & {
  // userId: string,
  type: CustomModelAction.StopPrediction;
  modelId: string;
  predictionId: string;
};

/** Result of the `StopPrediction` action. */
export type HandleStopCustomModelPredictionResponse = { ok: boolean; message: string };

/** Arguments referencing an already-stored face image by its storage path. */
export interface UploadReferenceHumanArgs {
  faceImageStoragePath: string;
}

/** Arguments carrying the face image itself as a `File` or `Blob`, plus the owning team. */
export interface UploadReferenceHumanFileArgs {
  faceImage: File | Blob;
  publicTeamId: PublicTeamId;
}

/**
 * The actual hardware type that ran the job when the GenericHardware approach
 * is used. See the corresponding type in the backend Firebase functions for
 * more details.
 */
export enum CustomModelTrainingActualHardwareType {
  Lepton = "L",
  Replicate = "R",
  Salad = "S",
}

/**
 * Record of one training run.
 *
 * `isCustomModelTrainingItem` validates only `id`, `status` and `input`.
 * The model id may live on `modelId` or on `input.modelId`; use
 * `getModelIdFromTraining` to read it.
 */
export interface CustomModelTrainingItem {
  id: string;
  jobId?: string;
  modelId?: string;
  displayName?: string;
  caption?: string;
  captionShortened?: string;
  userModifiedCaption?: string;
  status: CustomModelTrainingStatus;
  input: CustomModelTrainingInput;
  zipFileStoragePath?: string;
  timeCreated: Timestamp;
  timeModified: Timestamp;
  // NOTE(review): range (0–1 vs 0–100) is not visible here — confirm with the writer of this field.
  progress: number;
  // Legacy items do not have this; they can be assumed to be "R", which has no effect on anything anyway.
  actualHardwareType?: CustomModelTrainingActualHardwareType;
  numProactiveHealthchecks?: number;
}

/**
 * Runtime type guard for {@link CustomModelTrainingItem}.
 *
 * Performs a shallow structural check only: `item` must be truthy, have a
 * string `id`, and have truthy `status` and `input` properties.
 *
 * NOTE(review): `isCustomModelPredictionItem` performs the same checks, so
 * the two guards cannot tell a training item from a prediction item.
 *
 * @param item - Any value, e.g. raw document data.
 * @returns A strict boolean: `true` when the shallow check passes, otherwise `false`.
 */
export function isCustomModelTrainingItem(item: any): item is CustomModelTrainingItem {
  // Coerce explicitly: a bare `&&` chain would leak `null`, `undefined`, `""`
  // or the `input` object itself to callers instead of a boolean.
  return Boolean(item && typeof item.id === "string" && item.status && item.input);
}

/**
 * Derives the trigger word of a model from its id.
 *
 * @param modelId - The model id.
 * @returns The first five characters of `modelId` (fewer if the id is shorter), upper-cased.
 */
export function getModelTriggerWord({ modelId }: { modelId: string }) {
  const idPrefix = modelId.substring(0, 5);
  return idPrefix.toUpperCase();
}

/**
 * A single prediction record of a custom model.
 *
 * Note that `status` reuses the `CustomModelTrainingStatus` type rather than
 * a prediction-specific status type.
 */
export interface CustomModelPredictionItem {
  id: string;
  callerPublicUserId: PublicUserId;
  status: CustomModelTrainingStatus;
  input: CustomModelPredictionInput;
  // NOTE(review): presumably URLs or storage paths of the generated results — confirm.
  output?: string[];
  // NOTE(review): presumably keyed by model id — confirm against the writer of this field.
  usedModels: Record<string, boolean>;
  roles: PublicUserRoles;
  timeCreated: Timestamp;
  timeModified: Timestamp;
  backendTimeStarted?: Timestamp;
  backendTimeCompleted?: Timestamp;
  isDeleted: boolean;
}

/**
 * Runtime type guard for {@link CustomModelPredictionItem}.
 *
 * Performs a shallow structural check only: `item` must be truthy, have a
 * string `id`, and have truthy `status` and `input` properties. The other
 * required fields of the interface are not verified.
 *
 * @param item - Any value, e.g. raw document data.
 * @returns A strict boolean: `true` when the shallow check passes, otherwise `false`.
 */
export function isCustomModelPredictionItem(item: any): item is CustomModelPredictionItem {
  // Coerce explicitly: a bare `&&` chain would leak `null`, `undefined`, `""`
  // or the `input` object itself to callers instead of a boolean.
  return Boolean(item && typeof item.id === "string" && item.status && item.input);
}

/**
 * Shortens an id for display.
 *
 * @param id - The full id.
 * @returns The first five characters of `id`, or the whole id when it is shorter.
 */
export function getDisplayNameFromId(id: string) {
  const shortId = id.substring(0, 5);
  return shortId;
}

/**
 * Builds the fallback display name of a training: the literal prefix
 * `Training-` followed by the shortened form of the training id.
 *
 * @param training - Any partial training object that carries an `id`.
 * @returns The generated default name.
 */
function getTrainingDefaultDisplayName(
  training: Partial<CustomModelTrainingItem> & { id: string },
) {
  const shortId = getDisplayNameFromId(training.id);
  return "Training-" + shortId;
}

/**
 * Resolves the name to show for a training.
 *
 * @param training - Any partial training object that carries an `id`.
 * @returns `training.displayName` when it is truthy; otherwise the default
 *   name generated from the training id (an empty `displayName` also falls
 *   back to the default).
 */
export function getTrainingDisplayName(
  training: Partial<CustomModelTrainingItem> & { id: string },
) {
  const customName = training.displayName;
  if (customName) {
    return customName;
  }
  return getTrainingDefaultDisplayName(training);
}

/** Prefix character that `getModelTrainingMentionName` puts in front of a model/training mention. */
export const customModelMentionTrigger = "@";

/** Literal token that `cleanCustomModelCaptionTriggerFromCaption` strips from captions. */
export const customModelCaptionTrigger = "[trigger]";

/**
 * Removes the caption trigger token from a caption and normalizes whitespace.
 *
 * Every occurrence of {@link customModelCaptionTrigger} is removed, the result
 * is trimmed, and each run of whitespace is collapsed to a single space.
 *
 * @param caption - The raw caption text.
 * @returns The caption without any trigger token and without extra whitespace.
 */
export const cleanCustomModelCaptionTriggerFromCaption = (caption: string) => {
  // `String.prototype.replace` with a string pattern only removes the first
  // match; split/join removes every occurrence without needing regex escaping.
  return caption.split(customModelCaptionTrigger).join("").trim().replace(/\s+/g, " ");
};

/**
 * Builds the mention string for a specific training of a model, in the form
 * `<mention trigger><model display name>/<training display name>`.
 *
 * @param modelDisplayName - Display name of the model.
 * @param training - The training; its display name is resolved via `getTrainingDisplayName`.
 * @returns The assembled mention string.
 */
export function getModelTrainingMentionName({
  modelDisplayName,
  training,
}: {
  modelDisplayName: string;
  training: { displayName?: string; id: string };
}) {
  const trainingName = getTrainingDisplayName(training);
  return `${customModelMentionTrigger}${modelDisplayName}/${trainingName}`;
}

/**
 * Resolves the id of the model a training belongs to.
 *
 * @param training - The training item.
 * @returns `training.modelId` when it is truthy; otherwise `training.input.modelId`
 *   (which is `undefined` when `input` is missing).
 */
export function getModelIdFromTraining(training: CustomModelTrainingItem) {
  if (training.modelId) {
    return training.modelId;
  }
  return training.input?.modelId;
}

/**
 * Custom-model slice of the editor state.
 *
 * Every state field is paired with a `set…` function that accepts a
 * `StateUpdater` of the same value type.
 */
export interface CustomModelEditorState {
  // Currently selected model id and its loaded info/dataset.
  customModelId: string | undefined;
  setCustomModelId: (value: StateUpdater<string | undefined>) => void;
  customModels: Record<string, CustomModelInfo>;
  setCustomModels: (value: StateUpdater<Record<string, CustomModelInfo>>) => void;
  publicCustomModels: Record<string, CustomModelInfo>;
  setPublicCustomModels: (value: StateUpdater<Record<string, CustomModelInfo>>) => void;
  // NOTE(review): presumably keyed by model id, with the list of that model's trainings — confirm.
  publicCustomModelTrainings: Record<string, CustomModelTrainingItem[]>;
  setPublicCustomModelTrainings: (
    value: StateUpdater<Record<string, CustomModelTrainingItem[]>>,
  ) => void;
  customModelInfo: CustomModelInfo | undefined;
  setCustomModelInfo: (value: StateUpdater<CustomModelInfo | undefined>) => void;
  customModelDataset: CustomModelDataset | undefined;
  setCustomModelDataset: (value: StateUpdater<CustomModelDataset | undefined>) => void;
  // NOTE(review): presumably keyed by training id — confirm.
  customModelTrainings: Record<string, CustomModelTrainingItem>;
  setCustomModelTrainings: (value: StateUpdater<Record<string, CustomModelTrainingItem>>) => void;
  // NOTE(review): presumably keyed by model id — confirm.
  allCustomModelTrainings: Record<string, CustomModelTrainingItem[]>;
  setAllCustomModelTrainings: (
    value: StateUpdater<Record<string, CustomModelTrainingItem[]>>,
  ) => void;
  customModelWorkflow: FrontendDisplayTemplateType;
  setCustomModelWorkflow: (value: StateUpdater<FrontendDisplayTemplateType>) => void;
  // NOTE(review): presumably keyed by prediction id — confirm.
  customModelPredictions: Record<string, CustomModelPredictionItem>;
  setCustomModelPredictions: (
    value: StateUpdater<Record<string, CustomModelPredictionItem>>,
  ) => void;
  predictionLayout: CustomModelPredictionLayout;
  setPredictionLayout: (value: StateUpdater<CustomModelPredictionLayout>) => void;
}

/**
 * Returns an inert {@link CustomModelEditorState}: ids/info/dataset are
 * `undefined`, all record fields are empty objects, and every setter is
 * lodash's `noop`, so calling a setter has no effect.
 *
 * @returns A freshly created placeholder state object.
 */
export function getDummyCustomModelEditorState(): CustomModelEditorState {
  return {
    customModelId: undefined,
    setCustomModelId: noop,
    customModels: {},
    setCustomModels: noop,
    publicCustomModels: {},
    setPublicCustomModels: noop,
    publicCustomModelTrainings: {},
    setPublicCustomModelTrainings: noop,
    customModelInfo: undefined,
    setCustomModelInfo: noop,
    customModelDataset: undefined,
    setCustomModelDataset: noop,
    customModelTrainings: {},
    setCustomModelTrainings: noop,
    allCustomModelTrainings: {},
    setAllCustomModelTrainings: noop,
    customModelWorkflow: FrontendDisplayTemplateType.Custom,
    setCustomModelWorkflow: noop,
    customModelPredictions: {},
    setCustomModelPredictions: noop,
    // NOTE(review): this is `Grid`, whereas getDefaultCustomModelEditorState and
    // resetCustomModelEditorState both use `Row` — confirm the difference is intentional.
    predictionLayout: CustomModelPredictionLayout.Grid,
    setPredictionLayout: noop,
  };
}

/**
 * Builds the initial {@link CustomModelEditorState} wired to an editor-state
 * `set` function.
 *
 * Values start out empty/`undefined` (workflow `Custom`, prediction layout
 * `Row`); each setter is obtained from `getUpdaterFunction(set, key)` with the
 * name of the state key it belongs to.
 *
 * @param set - The editor-state set function passed on to `getUpdaterFunction`.
 * @returns A freshly created initial state object.
 */
export function getDefaultCustomModelEditorState(
  set: SetEditorStateFunction,
): CustomModelEditorState {
  return {
    customModelId: undefined,
    setCustomModelId: getUpdaterFunction(set, "customModelId"),
    customModels: {},
    setCustomModels: getUpdaterFunction(set, "customModels"),
    publicCustomModels: {},
    setPublicCustomModels: getUpdaterFunction(set, "publicCustomModels"),
    publicCustomModelTrainings: {},
    setPublicCustomModelTrainings: getUpdaterFunction(set, "publicCustomModelTrainings"),
    customModelInfo: undefined,
    setCustomModelInfo: getUpdaterFunction(set, "customModelInfo"),
    customModelDataset: undefined,
    setCustomModelDataset: getUpdaterFunction(set, "customModelDataset"),
    customModelTrainings: {},
    setCustomModelTrainings: getUpdaterFunction(set, "customModelTrainings"),
    allCustomModelTrainings: {},
    setAllCustomModelTrainings: getUpdaterFunction(set, "allCustomModelTrainings"),
    customModelWorkflow: FrontendDisplayTemplateType.Custom,
    setCustomModelWorkflow: getUpdaterFunction(set, "customModelWorkflow"),
    customModelPredictions: {},
    setCustomModelPredictions: getUpdaterFunction(set, "customModelPredictions"),
    predictionLayout: CustomModelPredictionLayout.Row,
    setPredictionLayout: getUpdaterFunction(set, "predictionLayout"),
  };
}

/**
 * Resets the custom-model editor state through its setters.
 *
 * Record fields are set to a fresh empty object, the selected id / info /
 * dataset are set to `undefined`, and the prediction layout is set to `Row`.
 * `customModelWorkflow` is not touched.
 *
 * @param state - The state whose setters are invoked.
 */
export function resetCustomModelEditorState(state: CustomModelEditorState) {
  // Steps run in the listed order; each is invoked on `state` itself and
  // every record setter receives its own new `{}` instance.
  const resetSteps: Array<() => void> = [
    () => state.setCustomModels({}),
    () => state.setCustomModelTrainings({}),
    () => state.setCustomModelId(undefined),
    () => state.setCustomModelInfo(undefined),
    () => state.setCustomModelDataset(undefined),
    () => state.setAllCustomModelTrainings({}),
    () => state.setPublicCustomModels({}),
    () => state.setPublicCustomModelTrainings({}),
    () => state.setCustomModelPredictions({}),
    () => state.setPredictionLayout(CustomModelPredictionLayout.Row),
  ];
  for (const runStep of resetSteps) {
    runStep();
  }
}

/**
 * Status of the training editor.
 * This is a numeric enum with implicit values (`Default` = 0, `Cancel` = 1).
 */
export enum CustomModelTrainingEditorStatus {
  Default,
  Cancel,
}

/** Subjects a custom-model prompt can refer to; values are snake_case string identifiers. */
export enum CustomModelPromptSubject {
  Background = "background",
  HumanAngle = "human_angle",
  Human = "human",
  Product = "product",
}

/**
 * Background template categories specific to jewelry.
 * NOTE(review): the string values look like user-facing labels ("Jewelry • …") — confirm before changing them.
 */
export enum CustomModelBackgroundJewelryTemplateCategory {
  Hand = "Jewelry • Hand",
  Display = "Jewelry • Display",
  Background = "Jewelry • Background",
  Face = "Jewelry • Face",
}

/** General (non-jewelry) background template categories. */
export enum CustomModelBackgroundDefaultTemplateCategory {
  Studio = "Studio",
  Outdoor = "Outdoor",
  Indoor = "Indoor",
}

/** Any background template category: either a jewelry-specific one or a default one. */
export type CustomModelBackgroundTemplateCategory =
  | CustomModelBackgroundJewelryTemplateCategory
  | CustomModelBackgroundDefaultTemplateCategory;

/**
 * Kinds of background template.
 * NOTE(review): the meaning of `Depth` (presumably depth-guided generation) is not visible in this file — confirm.
 */
export enum CustomModelBackgroundTemplateType {
  Default = "Default",
  Depth = "Depth",
}

/**
 * A selectable background template: a full and a short caption, a preview
 * image URL, the category it is listed under, and optionally the custom model
 * type it is associated with.
 */
export type CustomModelBackgroundTemplate = {
  caption: string;
  shortCaption: string;
  imageUrl: string;
  category: CustomModelBackgroundTemplateCategory;
  customModelType?: CustomModelType;
};

/**
 * Tells whether a custom model counts as a product model.
 *
 * @param customModel - The model info; a falsy value is tolerated.
 * @returns `false` when the model or its `customModelType` is missing;
 *   otherwise `true` for every type except `CustomModelType.VirtualModel`.
 */
export function isProductCustomModel(customModel: CustomModelInfo) {
  const modelType = customModel?.customModelType;
  return modelType ? modelType !== CustomModelType.VirtualModel : false;
}

/**
 * Tells whether a custom model is of type `CustomModelType.Fashion`.
 *
 * @param customModel - The model info; a falsy value is tolerated.
 * @returns `false` when the model or its `customModelType` is missing;
 *   otherwise whether the type equals `Fashion`.
 */
export function isFashionCustomModel(customModel: CustomModelInfo) {
  const modelType = customModel?.customModelType;
  return modelType ? modelType === CustomModelType.Fashion : false;
}

/**
 * Tells whether a custom model is of type `CustomModelType.VirtualModel`.
 *
 * @param customModel - The model info; a falsy value is tolerated.
 * @returns `false` when the model or its `customModelType` is missing;
 *   otherwise whether the type equals `VirtualModel`.
 */
export function isHumanCustomModel(customModel: CustomModelInfo) {
  const modelType = customModel?.customModelType;
  return modelType ? modelType === CustomModelType.VirtualModel : false;
}

/**
 * Decides whether human-angle options apply to a custom model.
 *
 * @param customModel - The model info; a falsy value is tolerated.
 * @returns `true` for `VirtualModel` and `Fashion` models; `false` for every
 *   other type and when the model or its `customModelType` is missing.
 */
export function showHumanAngleOptions(customModel: CustomModelInfo) {
  // Both predicates already return false for a missing model or missing type,
  // so no separate guard is needed here.
  return isHumanCustomModel(customModel) || isFashionCustomModel(customModel);
}

/** Layout options for the prediction layout state (`predictionLayout`): `Row` or `Grid`. */
export enum CustomModelPredictionLayout {
  Row = "Row",
  Grid = "Grid",
}

// CustomModelPredictionType is not defined in this file; use the one in
// packages/frontend/src/core/common/types/custom-model-prediction-type.ts instead.
// export enum CustomModelPredictionType {
//   Default = "Default",
//   FixDetails = "FixDetails",
//   GenerateFromPose = "GenerateFromPose",
//   FixLogoAndText = "FixLogoAndText",
//   RegenerateHuman = "RegenerateHuman",
//   UpscaleCreative = "UpscaleCreative",
//   ClarityUpscale = "ClarityUpscale",
//   GenerateVariations = "GenerateVariations",
//   InContextVariations = "ICV",
//   UpscaleFace = "UpscaleFace",
//   MultiStageGeneration = "MultiStageGeneration",
//   ICLight = "ICLight",
//   TryOn = "TryOn",
//   BuildAHuman = "BuildAHuman", // generating a diverse set of humans with different attributes, see https://docs.google.com/document/d/1o91Lc08h1QNnjeRpqp8apvBr_KPLBBoEPJG9FVV-whQ/edit
// }
