#!/usr/bin/env python3
# Copyright (C) 2022 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import re
import subprocess
import pathlib
import tempfile
import contextlib
import argparse
import itertools

# Repository layout. This script lives one directory below the repo root.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TOOLS_DIR = os.path.join(ROOT_DIR, "tools")
# Build directory handed to `gn gen` / `ninja -C`; protoc is built into it.
OUT_DIR = os.path.join(ROOT_DIR, "out", "tools")
NINJA = os.path.join(TOOLS_DIR, "ninja")
GN = os.path.join(TOOLS_DIR, "gn")
PROTOC_PATH = os.path.join(OUT_DIR, "protoc")
# Binary FileDescriptorSet written by protoc (--descriptor_set_out).
DESCRIPTOR_PATH = os.path.join(ROOT_DIR, "src", "trace_processor", "importers",
                               "proto", "atoms.descriptor")
# Include root that provides google/protobuf/descriptor.proto.
PROTOBUF_BUILTINS_DIR = os.path.join(ROOT_DIR, "buildtools", "protobuf", "src")
PROTO_LOGGING_URL = "https://android.googlesource.com/platform/frameworks/proto_logging.git"

# The patterns below are applied to the text-format dump of the descriptor
# set, so they depend on its exact indentation.
#
# Matches the entire `message_type { name: "Atom" ... }` entry: the header,
# every following line indented by at least four spaces, and the closing
# brace at two-space indentation.
ATOM_RE = r"  message_type {\n    name: \"Atom\"(\n    .+)+(\n  })"
# Captures (name, number) of one `field { ... }` entry inside that message.
FIELD_RE = r"    field {\n      name: \"([^\"]+)\"\n      number: ([0-9]+)"
# Captures (name, number, ...) of one `extension { ... }` entry whose extendee
# is .android.os.statsd.Atom.
EXTENSIONS_RE = r" extension {\n    name: \"([^\"]+)\"\n    extendee: \".android.os.statsd.Atom\"\n    number: ([0-9]+)(\n    .+)+(\n  })"
# Generated enum file; overwritten on every run.
ATOM_IDS_PATH = os.path.join(ROOT_DIR, "protos", "perfetto", "config", "statsd",
                             "atom_ids.proto")

# Skeleton of the generated atom_ids.proto. It is rendered with str.format():
# `{atoms}` is replaced by the enum value lines, and the doubled braces
# (`{{` / `}}`) come out as literal `{` / `}`. The template has no trailing
# newline, so the rendered text ends at the closing brace.
ATOM_IDS_TEMPLATE = """/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
syntax = "proto2";

package perfetto.protos;

// This enum is obtained by post-processing
// AOSP/frameworks/proto_logging/stats/atoms.proto through
// AOSP/external/perfetto/tools/update-statsd-descriptor, which extracts one
// enum value for each proto field defined in the upstream atoms.proto.
enum AtomId {{
  ATOM_UNSPECIFIED = 0;
{atoms}
}}"""


def call(*cmd, stdin=None):
  """Runs a command and returns everything it wrote to stdout.

  Args:
    cmd: The program followed by its arguments, one string per argument.
    stdin: Optional value forwarded to the child process as its standard
      input (e.g. an open file object).

  Returns:
    The captured stdout of the command, as bytes.

  Raises:
    SystemExit: With status 1 if the command exits with a non-zero status.
      The command line and the error are printed to stderr first.
  """
  try:
    return subprocess.check_output(cmd, stdin=stdin)
  except subprocess.CalledProcessError as e:
    # Diagnostics belong on stderr so they never mix with regular output.
    print("Error running the command:", file=sys.stderr)
    print(" ".join(cmd), file=sys.stderr)
    print(e, file=sys.stderr)
    # Use sys.exit() rather than the `exit` builtin: the latter is installed
    # by the site module and is not available under `python -S`.
    sys.exit(1)


# Extract core atoms. To do this we regex the pbtext
# of the descriptor. This is hopefully:
# - more stable than regexing atoms.proto directly
# - less complicated than finding, importing, and using the
#   Python protobuf library.
def atoms_from_descriptor():
  """Yields (name, number) for every atom recorded in the descriptor.

  The binary descriptor at DESCRIPTOR_PATH is decoded to text format by
  protoc and then scanned with ATOM_RE / FIELD_RE / EXTENSIONS_RE.

  Yields:
    (name, number) tuples, both as strings: first one per field of the
    `Atom` message, then one per extension of `.android.os.statsd.Atom`.

  Raises:
    RuntimeError: If the decoded descriptor has no `Atom` message matching
      ATOM_RE.
  """
  # The descriptor is a serialized proto, so open it in binary mode. It is
  # only needed while protoc reads it from stdin and is closed right after.
  with open(DESCRIPTOR_PATH, "rb") as descriptor_in:
    pbtext = call(
        PROTOC_PATH,
        f"--proto_path={PROTOBUF_BUILTINS_DIR}",
        f"{PROTOBUF_BUILTINS_DIR}/google/protobuf/descriptor.proto",
        "--decode=google.protobuf.FileDescriptorSet",
        stdin=descriptor_in).decode("utf8")

  # Core atoms: fields declared directly on the Atom message.
  atom_match = re.search(ATOM_RE, pbtext)
  if atom_match is None:
    # Fail with a readable message instead of a TypeError from indexing None.
    raise RuntimeError(
        f"Could not find the Atom message in the decoded {DESCRIPTOR_PATH}")
  for m in re.finditer(FIELD_RE, atom_match[0]):
    yield m[1], m[2]

  # Extensions: atoms declared as extensions of the Atom message.
  for m in re.finditer(EXTENSIONS_RE, pbtext):
    yield m[1], m[2]


def main():
  """Regenerates the atoms descriptor and the AtomId enum proto.

  Steps:
    1. Build protoc into OUT_DIR with gn + ninja.
    2. Locate the proto_logging sources: either under the directory given
       by --atoms-checkout, or in a fresh git clone inside a temporary
       directory that is removed when the function finishes.
    3. Compile stats/atoms.proto plus every .proto below stats/atoms into
       DESCRIPTOR_PATH.
    4. Write one enum value per atom into ATOM_IDS_PATH.

  Returns:
    None, which the caller turns into exit status 0.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--atoms-checkout")
  args = parser.parse_args()

  call(GN, "gen", OUT_DIR, "--args=is_debug=false")
  call(NINJA, "-C", OUT_DIR, "protoc")

  with contextlib.ExitStack() as stack:
    use_checkout = bool(args.atoms_checkout)
    if use_checkout:
      atoms_root = args.atoms_checkout
    else:
      # The temporary directory lives until the ExitStack is closed.
      atoms_root = stack.enter_context(tempfile.TemporaryDirectory())
    proto_logging_dir = os.path.join(atoms_root, "frameworks", "proto_logging")
    if not use_checkout:
      pathlib.Path(proto_logging_dir).mkdir(parents=True, exist_ok=True)
      call("git", "clone", PROTO_LOGGING_URL, proto_logging_dir)

    # Collect the extension protos. os.walk() yields entries in arbitrary
    # directory order, so sort to keep the generated files reproducible.
    extensions_path = os.path.join(proto_logging_dir, "stats", "atoms")
    extensions = []
    if os.path.isdir(extensions_path):
      extensions = sorted(
          os.path.join(dirpath, name)
          for dirpath, _, filenames in os.walk(extensions_path)
          for name in filenames
          if name.endswith(".proto"))

    # Write the descriptor.
    cmd = [
        f"--proto_path={PROTOBUF_BUILTINS_DIR}",
        f"--proto_path={atoms_root}",
        f"--descriptor_set_out={DESCRIPTOR_PATH}",
        "--include_imports",
    ] + extensions + [os.path.join(proto_logging_dir, "stats", "atoms.proto")]
    call(PROTOC_PATH, *cmd)

    # Write the enum: one `ATOM_<NAME> = <number>;` line per atom.
    lines = [
        f"  ATOM_{name.upper()} = {field};"
        for name, field in atoms_from_descriptor()
    ]
    with open(ATOM_IDS_PATH, "w") as atom_ids_out:
      atom_ids_out.write(ATOM_IDS_TEMPLATE.format(atoms="\n".join(lines)))


# Run main() only when executed as a script and hand its return value to
# the interpreter as the process exit status.
if __name__ == "__main__":
  exit_status = main()
  sys.exit(exit_status)
