3752 - Cvintete

De la Universitas MediaWiki

Se consideră numerele naturale nenule N și D urmate de o secvență S de N numere naturale nenule ordonate crescător, indexate de la 1 la N.

Cerința

Să se determine numărul de cvintete de indici (i1, i2, i3, i4, i5) ce verifică relațiile:

  • a • b • c = D
  • a • x2 + b • y2 = c2
  • a < b < c
  • x ≠ y

unde am notat cu a = S[i1], b = S[i2], c = S[i3], x = S[i4], y = S[i5]. Rezultatul se va afișa modulo 1.000.000.007.

Date de intrare

Fișierul de intrare input.txt conține pe prima linie două numere naturale nenule N și D cu semnificația

din enunț. Pe următoarea linie se vor afla N numere naturale nenule ordonate crescător.

Date de ieșire

Fișierul de ieșire output.txt va conține un singur număr natural care reprezintă rezultatul cerinței, modulo 1.000.000.007.

Exemplul 1

input.txt:

4 6

1 2 3 3

output.txt:

2

Explicație:

Cvintetele care respectă cerința sunt: (1, 2, 3, 1, 2), (1, 2, 4, 1, 2).

Exemplul 2

input.txt:

10 60

1 2 3 4 4 5 6 8 10 12

output.txt:

4

Explicație:

Cvintetele care respectă cerința sunt: (1, 6, 10, 8, 4), (1, 6, 10, 8, 5), (1, 7, 9, 2, 4), (1, 7, 9, 2, 5)

Rezolvare

import math

M = 1000000007
N = 25000
P = 1235789

def gcd(A, B):
    r = A % B
    while r != 0:
        A = B
        B = r
        r = A % B
    return B

def main():
    with open("input.txt", "r") as infile, open("output.txt", "w") as outfile:
        n, d = map(int, infile.readline().split())
        f = [0] * (N + 1)
        di = [0] * (N + 1)
        pp = [0] * (N + 1)
        v = [0] * (P + 1)
        k = 0
        m = 0

        l=list(map(int, infile.readline().split()))

        for i in range(n):
            x = l[i]
            f[x] += 1
            if x > m:
                m = x

        for i in range(1, m + 1):
            if (d % i == 0) and (f[i] > 0):
                k += 1
                di[k] = i
            pp[i] = i * i
            v[pp[i] % P] = i

        sol = 0

        for i in range(1, k - 1):
            a = di[i]
            for j in range(i + 1, k):
                b = di[j]
                e = a * b
                if d % e == 0:
                    c = d // e
                    if c > b and c <= m and f[c] > 0:
                        nr = f[a] * f[b] * f[c]
                        w = c * c
                        dc = gcd(a, b)
                        if w % dc == 0:
                            l = int(math.sqrt(w / b))
                            for u in range(1, l + 1):
                                if (w - b * pp[u]) % a == 0:
                                    z = (w - b * pp[u]) // a
                                    h = v[z % P]
                                    if h * h == z and h != u and h <= m:
                                        sol += nr * f[u] * f[h]

        outfile.write(str(sol % M) + "\n")

if __name__ == "__main__":
    main()