import { useCallback, useEffect, useRef, useState } from 'react'
import * as Sentry from '@sentry/browser'
import { hooks } from '@beelday/common'
import { InferenceSession, Tensor, TypedTensor } from 'onnxruntime-web'
import { bufferSize, modelPath } from './settings'

/** Value returned by {@link useInferenceSession}. */
interface InferenceSessionHandle {
	/** The loaded ONNX session, or `null` while no session has been set. */
	session: InferenceSession | null
	/**
	 * Runs the model on one buffer of samples.
	 * Resolves to `undefined` when no result could be produced.
	 */
	createDataPoint: (data: Float32Array) => Promise<number | undefined>
}

/** Signature of the {@link useInferenceSession} hook. */
type UseInferenceSession = () => InferenceSessionHandle

/**
 * React hook that loads the ONNX model at `modelPath` (WASM execution
 * provider) once on mount and exposes a function that runs it on a buffer of
 * samples.
 *
 * Besides the `input` tensor, the model is fed recurrent state tensors `h0` and
 * `c0`. They start as zero tensors of shape [2, 1, 64]. Whenever a run returns
 * `hn`/`cn` outputs, those replace the stored state, so the next call
 * continues from where the previous one left off.
 *
 * @returns `session` – the loaded session, or `null` until loading finishes
 *   (it stays `null` if loading fails; the error is reported to Sentry).
 *   `createDataPoint` – runs the model on `data` (reshaped to
 *   `[1, bufferSize]`) and resolves to element 1 of the `output` tensor, or to
 *   `undefined` when the session is not ready or inference throws (the error
 *   is reported to Sentry).
 */
export const useInferenceSession: UseInferenceSession = () => {
	const isMounted = hooks.useIsMounted()
	// Recurrent state carried from one model run to the next.
	const hn = useRef<TypedTensor<'float32'>>()
	const cn = useRef<TypedTensor<'float32'>>()
	const [session, setSession] = useState<InferenceSession | null>(null)

	// Runs once on mount: reset the recurrent state and start loading the model.
	useEffect(() => {
		// A Float32Array is zero-initialised, so this is an all-zero state tensor.
		const zeroState = () => new Tensor('float32', new Float32Array(2 * 64), [2, 1, 64])
		hn.current = zeroState()
		cn.current = zeroState()

		void InferenceSession.create(modelPath, {
			executionProviders: ['wasm'],
		})
			.then(createdSession => {
				// NOTE(review): a session that finishes loading after unmount is
				// dropped here without being explicitly released.
				if (isMounted()) {
					setSession(createdSession)
				}
			})
			.catch(error => Sentry.captureException(error))
	}, [])

	const createDataPoint = useCallback(
		async (data: Float32Array): Promise<number | undefined> => {
			const h0 = hn.current
			const c0 = cn.current
			if (!session || !h0 || !c0) {
				return undefined
			}

			try {
				const dataTensor = new Tensor('float32', data, [1, bufferSize])
				const output = await session.run({
					input: dataTensor,
					h0,
					c0,
				})

				// Carry the returned recurrent state into the next call; without
				// this every run would start again from the zero state.
				const nextH = output.hn
				const nextC = output.cn
				if (nextH && nextC) {
					hn.current = nextH as TypedTensor<'float32'>
					cn.current = nextC as TypedTensor<'float32'>
				}

				return (output.output.data as Float32Array)[1]
			} catch (error) {
				Sentry.captureException(error)
				return undefined
			}
		},
		// hn/cn are ref objects with a stable identity, so only `session` matters.
		[session]
	)

	return {
		session,
		createDataPoint,
	}
}
