Sep 12, 2025

Demonstrating the Accuracy of an AI Model through Zero-Knowledge Proof is Even Faster Than Direct Computation


1. Machine Learning and Zero-Knowledge Proofs: How Did They Meet?

In recent years, as research on machine learning theory and algorithms has deepened, more and more AI products have reached large-scale market adoption, greatly changing how many traditional industries operate. This shift is expected to attract even more research, investment, and engineering effort to the field. In the last few years in particular, large AI models such as GPT have made remarkable progress and have become indispensable tools in people's daily lives.

As AI models become able to handle harder problems, the models behind them also grow more complex. For example, Google recently announced its new model, Gemini, claiming it to be the most complex AI model to date. Training such a model to the desired performance takes a prohibitive amount of time and money. For ordinary users, however, this complexity is hidden behind simple APIs, so they cannot see how the model actually runs.

From the perspective of AI service providers, this design is reasonable. Training requires a significant investment of time and money, and the resulting parameter set is a valuable company asset. Serving the model through an API keeps the parameters from leaking during use, while still making the model convenient for users and lowering the barrier to using AI.

However, from the user's perspective, this design raises a problem: how can we be sure that the AI model vendor has provided us with the right model? This question may sound confusing, but a simple example makes it concrete. Suppose a website needs an AI model that classifies ads by customer group, and the AI model vendor (e.g., Google) offers several models with different accuracy levels and costs. The site wants the model with 90% accuracy, priced at $20,000, but the vendor serves a model with 85% accuracy instead. If the predictions of the two models are hard to tell apart, the site has no way of knowing which model it is actually getting.

This example may still seem unconvincing, because as users we tend to focus on the results a model produces rather than on which model produced them. Nevertheless, across the wide range of AI applications, the quality of a result cannot always be measured easily. In scenarios such as medical diagnosis, financial investment, or predicting enemy actions in a war, the predictions carry high stakes, and a wrong prediction is unacceptable.

Fortunately, zero-knowledge proofs (ZKP) can resolve this dilemma. Generally speaking, a zero-knowledge proof is a cryptographic proof system in which the prover convinces the verifier that a statement is true without disclosing any private information beyond that fact. For more details about zero-knowledge proofs, readers can refer to this article; we will not describe them in detail here.

However, proving a complex AI model directly with ZKP is very difficult. A ZKP system requires the model's computation to be converted into an arithmetic circuit or constraint system, such as R1CS (Rank-1 Constraint System), which expresses the computation using only addition and multiplication gates. Converting a complex AI model into this form is often very inefficient. For example, proving a VGG16 model (a common CNN model) takes about 30 minutes. For more complex AI models, such as the Twitter recommendation algorithm, direct proving with ZKP is impractical: generating the proof takes about 6 hours.
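To make the "addition and multiplication only" restriction concrete, here is a minimal sketch of a classic toy R1CS that proves knowledge of $x$ with $x^3 + x + 5 = 35$. Each constraint has the form $(A_i \cdot w)(B_i \cdot w) = C_i \cdot w$ over a witness vector $w$. The witness layout and the helper names (`witness`, `satisfies`) are illustrative assumptions, and real systems work over a large prime field rather than Python integers.

```python
# Toy R1CS for "I know x such that x**3 + x + 5 == 35" (illustrative only).
# Witness layout (assumption): w = [1, x, out, v1, v2, v3] with
#   v1 = x * x, v2 = v1 * x, v3 = v2 + x, out = v3 + 5
# Constraint i holds when (A[i] . w) * (B[i] . w) == (C[i] . w).

def dot(row, w):
    return sum(r * v for r, v in zip(row, w))

A = [
    [0, 1, 0, 0, 0, 0],  # x
    [0, 0, 0, 1, 0, 0],  # v1
    [0, 1, 0, 0, 1, 0],  # x + v2
    [5, 0, 0, 0, 0, 1],  # 5 + v3
]
B = [
    [0, 1, 0, 0, 0, 0],  # x
    [0, 1, 0, 0, 0, 0],  # x
    [1, 0, 0, 0, 0, 0],  # 1 (additions ride on a multiplication by 1)
    [1, 0, 0, 0, 0, 0],  # 1
]
C = [
    [0, 0, 0, 1, 0, 0],  # v1
    [0, 0, 0, 0, 1, 0],  # v2
    [0, 0, 0, 0, 0, 1],  # v3
    [0, 0, 1, 0, 0, 0],  # out
]

def satisfies(w):
    """True if the witness satisfies every rank-1 constraint."""
    return all(dot(a, w) * dot(b, w) == dot(c, w) for a, b, c in zip(A, B, C))

def witness(x):
    """Honest witness for input x."""
    v1 = x * x
    v2 = v1 * x
    v3 = v2 + x
    return [1, x, v3 + 5, v1, v2, v3]

print(satisfies(witness(3)), witness(3)[2])
```

Each multiplication costs one constraint. Non-arithmetic steps common in AI models, such as ReLU or max-pooling comparisons, have to be emulated with extra constraints, which is one reason circuits for large models grow so big.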

Therefore, proving an AI model directly with ZKP seems unwise, and we need other ways to reduce the cost of the proof. Our goal is to "prove a model," not to "recompute the model inside ZKP." In other words, we do not want to simply "translate" the model's algorithms into a ZKP description, because with the right techniques, proving that a computation is correct can, in some cases, cost even less than performing the computation itself!

2. Difference between Validation and Computation

Suppose we have three matrices, $A$, $B$, and $C$, and we want to prove that $AB = C$. The most intuitive way is to compute every element of $AB$ and check whether it equals the corresponding element of $C$, i.e., $c_{ij} \overset{?}{=} \sum_{k=1}^{n} a_{ik}b_{kj}$. Alternatively, we can multiply both sides on the right by a randomly chosen vector $\boldsymbol{x} = (x_1, x_2, \ldots, x_n)$ and check $(AB)\boldsymbol{x} \overset{?}{=} C\boldsymbol{x}$. What is the difference? At first sight this seems counterintuitive, since it appears to add more computation. Let's look at the details.

$$A = \begin{pmatrix} a_{11} & \cdots & a_{1n} \\ \vdots & \ddots & \vdots \\ a_{n1} & \cdots & a_{nn} \end{pmatrix}, \quad B = \begin{pmatrix} b_{11} & \cdots & b_{1n} \\ \vdots & \ddots & \vdots \\ b_{n1} & \cdots & b_{nn} \end{pmatrix}, \quad C = \begin{pmatrix} c_{11} & \cdots & c_{1n} \\ \vdots & \ddots & \vdots \\ c_{n1} & \cdots & c_{nn} \end{pmatrix}$$

$$(AB)\boldsymbol{x} \overset{?}{=} C\boldsymbol{x} \;\Rightarrow\; A(B\boldsymbol{x}) \overset{?}{=} C\boldsymbol{x}$$

In the first case, we need $n$ multiplications for every $c_{ij}$ in $C$. Since $C$ is $n \times n$, this takes $n^3$ multiplications, i.e., time complexity $\mathcal{O}(n^3)$. In the second case, instead of computing $AB$ directly, we first compute $B\boldsymbol{x}$, where $\boldsymbol{x}$ is a vector of size $n$. This takes $n^2$ multiplications, and the result $B\boldsymbol{x}$ is another vector of size $n$. We then compute $A(B\boldsymbol{x})$, which again takes $n^2$ multiplications, and $C\boldsymbol{x}$, which takes another $n^2$. In total we need $3n^2$ multiplications, so the complexity is $\mathcal{O}(n^2)$. The price is that the check is probabilistic: if $AB \neq C$ and $\boldsymbol{x}$ is drawn uniformly from $\{0,1\}^n$, the test wrongly passes with probability at most $1/2$, so repeating it $k$ times with fresh random vectors pushes the error below $2^{-k}$. This is known as Freivalds' algorithm, given by Rūsiņš Mārtiņš Freivalds in 1979.
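A minimal Python sketch of Freivalds' check, assuming integer matrices stored as lists of lists and random vectors drawn from $\{0,1\}^n$. The names `freivalds`, `matvec`, `matmul`, and `trials` are illustrative, not from the source. Each trial costs $\mathcal{O}(n^2)$, while `matmul` shows the $\mathcal{O}(n^3)$ approach it replaces.

```python
import random

def matmul(X, Y):
    """Naive n x n product: n^3 multiplications."""
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def matvec(M, v):
    """Matrix-vector product: n^2 multiplications."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def freivalds(A, B, C, trials=20, rng=random):
    """Probabilistically check AB == C by testing A(Bx) == Cx for random x in {0,1}^n."""
    n = len(A)
    for _ in range(trials):
        x = [rng.randint(0, 1) for _ in range(n)]
        if matvec(A, matvec(B, x)) != matvec(C, x):
            return False  # this x is a witness that AB != C
    return True  # if AB != C, we get here with probability <= 2**-trials

rng = random.Random(0)
n = 30
A = [[rng.randint(-9, 9) for _ in range(n)] for _ in range(n)]
B = [[rng.randint(-9, 9) for _ in range(n)] for _ in range(n)]
print(freivalds(A, B, matmul(A, B), rng=rng))
```

The number of trials trades cost against confidence: each extra trial adds $\mathcal{O}(n^2)$ work and halves the worst-case probability of accepting a wrong $C$.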

This is the core message we want to convey: "Validation" does not equate to "recalculation." With some mathematical methods and techniques, it is possible to prove a complex algorithm with reduced complexity. In the next section, we will see how to use this idea to prove the convolutional layer in CNN.

3. Example: Validation of 2-D Convolution Even Faster Than Computation

All the content in this section is drawn from zkCNN: Zero Knowledge Proofs for Convolutional Neural Network Predictions and Accuracy, published in 2021. The description involves some mathematical theory and derivations. For the reader's convenience, many derivations and algorithmic details have been omitted, and some expressions are simplified in ways that are not formally precise. For a complete understanding, we strongly recommend reading the original paper. You can find the paper here.

The sumcheck protocol is a classic interactive proof. The sumcheck problem is to verify the sum of a multivariate polynomial $g$ over all binary inputs. Simply put, the purpose of the protocol is to check the following value:

$$H := \sum_{b_1, b_2, \ldots, b_l \in \{0,1\}} g(b_1, b_2, \ldots, b_l)$$

The prover's running time is about $\mathcal{O}(2^l)$, and, for a polynomial of constant degree in each variable, the proof size is $\mathcal{O}(l)$.
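The sketch below simulates both parties of the sumcheck protocol in one function. It assumes $g$ is multilinear and given by its $2^l$ values on $\{0,1\}^l$ (first variable = most significant bit of the table index), works modulo the prime $2^{61}-1$ (an arbitrary choice), and uses an honest prover; all function names are hypothetical. Because $g$ is multilinear, each round's univariate polynomial has degree 1, so the prover sends two values per round (an $\mathcal{O}(l)$-size proof), and halving the table each round gives $\mathcal{O}(2^l)$ total prover work.

```python
import random

P = 2**61 - 1  # prime modulus (arbitrary choice for this sketch)

def fold(t, r):
    """Fix the first (most significant) variable of a multilinear table to r."""
    half = len(t) // 2
    return [(t[j] * (1 - r) + t[j + half] * r) % P for j in range(half)]

def mle_eval(table, point):
    """Evaluate the multilinear extension of `table` at `point`."""
    t = list(table)
    for r in point:
        t = fold(t, r)
    return t[0]

def sumcheck(table, claimed_sum, rng=random):
    """Simulate an honest prover and the verifier for sum(table) == claimed_sum."""
    claim = claimed_sum % P
    t, point = list(table), []
    while len(t) > 1:
        half = len(t) // 2
        g0, g1 = sum(t[:half]) % P, sum(t[half:]) % P  # prover sends g_i(0), g_i(1)
        if (g0 + g1) % P != claim:                      # verifier: g_i(0) + g_i(1) == claim
            return False
        r = rng.randrange(P)                            # verifier's random challenge
        claim = (g0 * (1 - r) + g1 * r) % P             # new claim is g_i(r)
        t = fold(t, r)                                  # prover's table halves each round
        point.append(r)
    return claim == mle_eval(table, point)              # single oracle query to g

rng = random.Random(1)
table = [rng.randrange(P) for _ in range(16)]           # l = 4 variables
print(sumcheck(table, sum(table), rng))
```

Here `mle_eval` stands in for the verifier's oracle access to $g$ at the final random point.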

2-D convolution is a very common operation in CNNs. The 2-D convolution operation is $A * W$, where $A$ is an $n \times n$ matrix and $W$ is a $w \times w$ matrix, with $n$ much greater than $w$. When validating 2-D convolutions in a CNN, the computation can be reduced to a 1-D convolution. It is well known that 1-D convolution is equivalent to multiplying two univariate polynomials (the coefficients of the product are the convolution of the two coefficient vectors), and FFT/IFFT is a widely used technique for evaluating univariate polynomial multiplication. The algorithm is:

$$A(x) * W(x) = \text{IFFT}\big(\text{FFT}(A(x)) \circ \text{FFT}(W(x))\big)$$

where $\circ$ denotes the element-wise product. Ignoring the FFT and IFFT, the cost of this evaluation for an $n \times n$ matrix is $\mathcal{O}(n^2)$. The FFT and IFFT are defined as:

$$\hat{f}_j = \text{FFT}(f) = \sum_{i=0}^{n-1} f_i\,\omega^{ij}, \qquad f_i = \text{IFFT}(\hat{f}) = \frac{1}{m}\sum_{j=0}^{m-1} \hat{f}_j\,\omega^{-ij}$$
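The reduction from 2-D to 1-D convolution can be illustrated with a Kronecker-style substitution: pad each row to a stride $N = n + w - 1$, flatten, convolve in 1-D via FFT, and reshape. This is a floating-point NumPy sketch of that idea, not necessarily the exact layout used in zkCNN (which works over a finite field), and the function names are hypothetical. It computes the full (not "valid") convolution; CNN layers usually compute cross-correlation, which is the same operation with $W$ flipped.

```python
import numpy as np

def conv2d_direct(A, W):
    """Full 2-D convolution from the definition: C[i+u, j+v] += A[i, j] * W[u, v]."""
    n, w = A.shape[0], W.shape[0]
    C = np.zeros((n + w - 1, n + w - 1))
    for i in range(n):
        for j in range(n):
            C[i:i + w, j:j + w] += A[i, j] * W
    return C

def conv2d_via_1d_fft(A, W):
    """Same result via one 1-D convolution computed as IFFT(FFT(a) * FFT(k))."""
    n, w = A.shape[0], W.shape[0]
    N = n + w - 1                  # row stride wide enough that output rows never mix
    a = np.zeros((n, N))
    a[:, :n] = A
    k = np.zeros((w, N))
    k[:, :w] = W
    a, k = a.ravel(), k.ravel()    # coefficient vectors of two univariate polynomials
    L = len(a) + len(k) - 1        # length of the product polynomial
    c = np.fft.ifft(np.fft.fft(a, L) * np.fft.fft(k, L)).real
    return c[:N * N].reshape(N, N)  # coefficient p*N + q holds C[p, q]

rng = np.random.default_rng(0)
A = rng.integers(-5, 6, size=(7, 7)).astype(float)
W = rng.integers(-3, 4, size=(3, 3)).astype(float)
print(np.allclose(conv2d_direct(A, W), conv2d_via_1d_fft(A, W)))
```

The stride $N$ is what makes the substitution safe: every output column index is below $N$, so coefficient $pN + q$ can only come from output position $(p, q)$.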

If we directly compute $\text{FFT}(A)$, the complexity is $\mathcal{O}(n^{2}\log n^{2})$. Instead of recomputing the FFT inside the proof, we apply the sumcheck protocol to the FFT together with some clever observations. Formally, let $\boldsymbol{c} = (c_0, c_1, \ldots, c_{n-1})$ be the coefficient vector of a polynomial, and let $\boldsymbol{a} = (a_0, a_1, \ldots, a_{m-1})$ be the vector of its evaluations at $(\omega^{0}, \omega^{1}, \ldots, \omega^{m-1})$, where $\omega$ is a primitive $m$-th root of unity, so that $\omega^{m} = 1$. The FFT can then be written as $\boldsymbol{a} = F\boldsymbol{c}$, where $F$ is:

$$F = \begin{pmatrix} 1 & 1 & 1 & \cdots & 1 & 1 \\ 1 & \omega^{1} & \omega^{2} & \cdots & \omega^{n-2} & \omega^{n-1} \\ 1 & \omega^{2} & \omega^{4} & \cdots & \omega^{2(n-2)} & \omega^{2(n-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 1 & \omega^{m-2} & \omega^{2(m-2)} & \cdots & \omega^{(m-2)(n-2)} & \omega^{(m-2)(n-1)} \\ 1 & \omega^{m-1} & \omega^{2(m-1)} & \cdots & \omega^{(m-1)(n-2)} & \omega^{(m-1)(n-1)} \end{pmatrix}$$

Writing the row index $y \in \{0,1\}^{\log m}$ and the column index $x \in \{0,1\}^{\log n}$ in binary, this reads:

$$a(y) = \sum_{x \in \{0,1\}^{\log n}} F(y, x)\, c(x)$$
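As a quick numerical check of $\boldsymbol{a} = F\boldsymbol{c}$, the sketch below builds the $m \times n$ matrix $F$ over the complex numbers with $\omega = e^{-2\pi i/m}$ (chosen to match NumPy's sign convention; zkCNN itself works over a finite field) and compares it with `numpy.fft.fft`, which zero-pads $\boldsymbol{c}$ to length $m$. It also applies the IFFT formula with its $1/m$ factor. `dft_matrix` is a hypothetical helper.

```python
import numpy as np

def dft_matrix(m, n, sign=-1):
    """m x n matrix with entries omega**(row*col), where omega = exp(sign * 2*pi*i / m)."""
    row = np.arange(m).reshape(-1, 1)  # evaluation point omega**row
    col = np.arange(n).reshape(1, -1)  # coefficient index
    return np.exp(sign * 2j * np.pi * row * col / m)

n, m = 5, 8
c = np.array([3.0, -1.0, 4.0, 1.0, -5.0])
F = dft_matrix(m, n)
a = F @ c                                     # evaluations at omega**0 .. omega**(m-1)
c_back = dft_matrix(m, m, sign=+1) @ a / m    # IFFT: (1/m) * sum_j a_j * omega**(-i*j)

print(np.allclose(a, np.fft.fft(c, m)), np.allclose(c_back[:n], c))
```

The recovered vector has length $m$; its last $m - n$ entries are (numerically) zero, reflecting the zero padding of $\boldsymbol{c}$.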

In practice, the protocol works with the multilinear extensions $\widetilde{a}$, $\widetilde{c}$, and $\widetilde{F}$ of $a$, $c$, and $F$:

$$\widetilde{a}(y) = \sum_{x \in \{0,1\}^{\log n}} \widetilde{F}(y, x)\,\widetilde{c}(x)$$
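The identity can be checked numerically. The sketch below works in the prime field $\mathbb{F}_P$ with $P = 998244353$, a common FFT-friendly prime whose smallest primitive root is 3 (both are assumptions of this example, not choices taken from the paper). It evaluates multilinear extensions through the standard $\mathrm{eq}$ table in $\mathcal{O}(n)$ time and compares both sides at a random point $y \in \mathbb{F}^{\log m}$. Because the multilinear extension is linear in the table, the two sides agree for every $y$.

```python
import random

P = 998244353  # prime with P - 1 = 119 * 2**23, so 2^k-th roots of unity exist
G = 3          # primitive root modulo P

def eq_table(r):
    """eq(b, r) for all b in {0,1}^len(r), with r[0] as the most significant bit."""
    t = [1]
    for ri in r:
        t = [x % P for v in t for x in (v * (1 - ri), v * ri)]
    return t

def mle(values, r):
    """Multilinear extension of `values` (length 2^len(r)) at r, in O(len(values))."""
    return sum(v * e for v, e in zip(values, eq_table(r))) % P

m, n = 8, 4                          # m evaluation points, n coefficients (n <= m)
omega = pow(G, (P - 1) // m, P)      # primitive m-th root of unity
F = [[pow(omega, row * col, P) for col in range(n)] for row in range(m)]

rng = random.Random(0)
c = [rng.randrange(P) for _ in range(n)]
a = [sum(F[row][col] * c[col] for col in range(n)) % P for row in range(m)]  # a = F c

y = [rng.randrange(P) for _ in range(3)]   # random point in F^(log m)
lhs = mle(a, y)
# For boolean x, F~(y, x) is the MLE of column x of F, and c~(x) = c[x].
rhs = sum(mle([F[row][col] for row in range(m)], y) * c[col] for col in range(n)) % P
print(lhs == rhs)
```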

Suppose we have a function $f(x)$ on $x \in \{0,1\}^{\log n}$ (equivalently, a vector of $n$ values). Evaluating its multilinear extension $\widetilde{f}(x)$ at a point $x \in \mathbb{F}^{\log n}$ takes $\mathcal{O}(n)$ time. For the matrix $F(y, x)$ with $y \in \{0,1\}^{\log m}$ and $x \in \{0,1\}^{\log n}$, evaluating $\widetilde{F}(y, x)$ at a point $y \in \mathbb{F}^{\log m}$, $x \in \mathbb{F}^{\log n}$ generically takes $\mathcal{O}(nm)$ time. In the protocol, however, we first need $\widetilde{F}(y, x)$ only for a fixed $y \in \mathbb{F}^{\log m}$ and all $x \in \{0,1\}^{\log n}$, and thanks to the structure of $F$ this table can be computed much more cheaply. Writing $\mathcal{X} \in [n]$ for the integer whose binary representation is $x$, the closed form is:

$$\widetilde{F}(y, x) = \prod_{i=0}^{\log m - 1}\left((1 - y_i) + y_i \cdot \omega_{2^{i+1}}^{\mathcal{X}}\right), \qquad y_i \in \mathbb{F},\ \mathcal{X} \in [n]$$

where $\omega_{2^{i+1}} = \omega^{m/2^{i+1}}$ (with the bits of $y$ ordered so that $y_0$ is the most significant bit of the row index). Because $\omega$ is an $m$-th root of unity, $\omega^{m} = 1$, so $\omega_{2^{i+1}}^{2^{i+1}} = 1$ and $\omega_{2^{i+1}}^{\mathcal{X}}$ takes at most $2^{i+1}$ distinct values over all $\mathcal{X} \in [n]$. We can precompute all $m$ distinct powers of $\omega$ in $\mathcal{O}(m)$ time, so the total complexity of validating the FFT is $\mathcal{O}(n + m)$.
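A sketch of how the closed form yields the cheap table, assuming the prime field with $P = 998244353$ and primitive root 3, $m$ a power of two, $n \le m$, and $y_0$ as the most significant bit. Because the $i$-th factor depends only on $\mathcal{X} \bmod 2^{i+1}$, the product can be built by doubling a table from size 1 to size $m$, touching about $2m$ entries after the $\mathcal{O}(m)$ precomputation of the powers of $\omega$; reading off the $n$ needed entries then costs $\mathcal{O}(n)$. This is one way to realize the $\mathcal{O}(n+m)$ bound, not necessarily the paper's exact procedure, and the function names are hypothetical. The generic $\mathcal{O}(nm)$ version serves as a reference.

```python
import random

P = 998244353  # FFT-friendly prime (assumption of this sketch)
G = 3          # primitive root modulo P

def F_row_closed_form(y, n, m):
    """F~(y, X) for X in [n] via prod_i ((1 - y_i) + y_i * omega_{2^{i+1}}^X), O(n + m)."""
    omega = pow(G, (P - 1) // m, P)
    pw = [1] * m                      # precompute omega^0 .. omega^(m-1): O(m)
    for k in range(1, m):
        pw[k] = pw[k - 1] * omega % P
    table = [1]                       # product of factors 0..i-1, indexed by X mod 2^i
    for i, yi in enumerate(y):
        size = 2 ** (i + 1)
        step = m // size              # omega_{2^{i+1}} = omega^(m / 2^{i+1})
        table = [table[r % (size // 2)] * ((1 - yi) + yi * pw[step * r]) % P
                 for r in range(size)]
    return [table[X % m] for X in range(n)]

def F_row_generic(y, n, m):
    """Reference: sum over all rows Y of eq(y, Y) * omega^(Y*X), O(n*m)."""
    omega = pow(G, (P - 1) // m, P)
    eq = [1]
    for yi in y:                      # eq table with y[0] as the most significant bit
        eq = [e % P for v in eq for e in (v * (1 - yi), v * yi)]
    return [sum(eq[Y] * pow(omega, Y * X, P) for Y in range(m)) % P for X in range(n)]

rng = random.Random(3)
m, n = 16, 12
y = [rng.randrange(P) for _ in range(4)]  # log m = 4 coordinates
print(F_row_closed_form(y, n, m) == F_row_generic(y, n, m))
```

On a Boolean $y$ the closed form reduces to a single row of $F$; for example $y = (1, 0, 1, 1)$ selects row $Y = 8 + 2 + 1 = 11$, giving $\omega^{11\mathcal{X}}$.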

Returning to 2-D convolution: here $a(x)$ is actually an $n \times n$ matrix, so the complexity of validating $\text{FFT}(a)$ is $\mathcal{O}(n^2)$. The element-wise product also costs $\mathcal{O}(n^2)$. The total complexity is therefore $\mathcal{O}(n^2)$, which is even faster than computing the convolution with FFT, at $\mathcal{O}(n^2 \log n)$!

Reference