3752 - Cvintete

De la Universitas MediaWiki
Versiunea pentru tipărire nu mai este suportată și poate avea erori de randare. Vă rugăm să vă actualizați bookmarkurile browserului și să folosiți funcția implicită de tipărire a browserului.

Se consideră numerele naturale nenule N și D urmate de o secvență S de N numere naturale nenule ordonate crescător, indexate de la 1 la N.

Cerința

Să se determine numărul de cvintete de indici (i1, i2, i3, i4, i5) ce verifică relațiile:

  • a • b • c = D
  • a • x2 + b • y2 = c2
  • a < b < c
  • x ≠ y

unde am notat cu a = S[i1], b = S[i2], c = S[i3], x = S[i4], y = S[i5]. Rezultatul se va afișa modulo 1.000.000.007.

Date de intrare

Fișierul de intrare input.txt conține pe prima linie două numere naturale nenule N și D cu semnificația

din enunț. Pe următoarea linie se vor afla N numere naturale nenule ordonate crescător.

Date de ieșire

Fișierul de ieșire output.txt va conține un singur număr natural care reprezintă rezultatul cerinței, modulo 1.000.000.007.

Exemplul 1

input.txt:

4 6

1 2 3 3

output.txt:

2

Explicație:

Cvintetele care respectă cerința sunt: (1, 2, 3, 1, 2), (1, 2, 4, 1, 2).

Exemplul 2

input.txt:

10 60

1 2 3 4 4 5 6 8 10 12

output.txt:

4

Explicație:

Cvintetele care respectă cerința sunt: (1, 6, 10, 8, 4), (1, 6, 10, 8, 5), (1, 7, 9, 2, 4), (1, 7, 9, 2, 5)

Rezolvare

import math

M = 1000000007
N = 25000
P = 1235789

def gcd(A, B):
    r = A % B
    while r != 0:
        A = B
        B = r
        r = A % B
    return B

def main():
    with open("input.txt", "r") as infile, open("output.txt", "w") as outfile:
        n, d = map(int, infile.readline().split())
        f = [0] * (N + 1)
        di = [0] * (N + 1)
        pp = [0] * (N + 1)
        v = [0] * (P + 1)
        k = 0
        m = 0

        l=list(map(int, infile.readline().split()))

        for i in range(n):
            x = l[i]
            f[x] += 1
            if x > m:
                m = x

        for i in range(1, m + 1):
            if (d % i == 0) and (f[i] > 0):
                k += 1
                di[k] = i
            pp[i] = i * i
            v[pp[i] % P] = i

        sol = 0

        for i in range(1, k - 1):
            a = di[i]
            for j in range(i + 1, k):
                b = di[j]
                e = a * b
                if d % e == 0:
                    c = d // e
                    if c > b and c <= m and f[c] > 0:
                        nr = f[a] * f[b] * f[c]
                        w = c * c
                        dc = gcd(a, b)
                        if w % dc == 0:
                            l = int(math.sqrt(w / b))
                            for u in range(1, l + 1):
                                if (w - b * pp[u]) % a == 0:
                                    z = (w - b * pp[u]) // a
                                    h = v[z % P]
                                    if h * h == z and h != u and h <= m:
                                        sol += nr * f[u] * f[h]

        outfile.write(str(sol % M) + "\n")

if __name__ == "__main__":
    main()