summaryrefslogtreecommitdiff
path: root/src/xpit_/args_helpers.py
blob: 1e79911201768cbac0925c7a2770cce54934302d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from sys import stderr
from . import error
# Module-level alias so the error type from the sibling `error` module can be
# referenced simply as `Err` in annotations and constructor calls below.
Err = error.Err

def print_error(msg: str) -> None:
    """Write *msg* to stderr, terminated by a newline."""
    line = f"{msg}\n"
    stderr.write(line)

def err_arg_unrecognized(arg: str) -> None:
    """Report *arg* on stderr as an argument the parser does not recognize."""
    message = f"Unrecognized argument: {arg}"
    print_error(message)

# Error codes handed to Err() when parse_arg fails:
# - ERR_UNRECOGNIZED_ARG: the argument is not a flag, or the flag is not
#   present in the flag/value map.
# - ERR_NOT_ENOUGH_ARGS: argv ended before the flag received all its values.
ERR_UNRECOGNIZED_ARG = "ERR_UNRECOGNIZED_ARG"
ERR_NOT_ENOUGH_ARGS = "ERR_NOT_ENOUGH_ARGS"

def parse_arg(
    argv: list[str],
    i: int,
    flag_value_map: dict[str, int]
) -> tuple[str, list[str], int, Err | None]:
    """Parse the flag at ``argv[i]`` together with the values it consumes.

    Accepted spellings are ``--flag``, ``--flag=value``, ``-f``, ``-fvalue``
    and ``-f=value``.  Values that are not attached to the flag itself are
    taken from the following ``argv`` entries until the flag has as many
    values as ``flag_value_map`` asks for.

    Args:
        argv: Full argument list.
        i: Index of the argument to parse.
        flag_value_map: Maps a ``|``-separated set of flag aliases, written
            without leading dashes (e.g. ``"h|help"``), to the number of
            values that flag takes.

    Returns:
        ``(flag, values, next_i, None)`` on success, where ``flag`` is the
        first alias of the matching flag set and ``next_i`` is the index of
        the first argument that was not consumed.  On failure a message is
        written to stderr and ``("", [], 0, Err(code))`` is returned, with
        ``code`` being ``ERR_UNRECOGNIZED_ARG`` or ``ERR_NOT_ENOUGH_ARGS``.
    """
    def fail(msg: str) -> tuple[str, list[str], int, Err | None]:
        return "", [], 0, Err(msg)

    arg = argv[i]
    i += 1
    # Every valid flag is a dash plus at least one more character.  Checking
    # the length first keeps "" and a lone "-" from raising IndexError on the
    # arg[0] / arg[1] lookups.
    if len(arg) < 2 or arg[0] != '-':
        err_arg_unrecognized(arg)
        return fail(ERR_UNRECOGNIZED_ARG)

    values: list[str] = []
    if arg[1] == '-':
        # --flag or --flag=value
        flag, has_value, value = arg[2:].partition('=')
        if has_value:
            values.append(value)
    else:
        # -f, -fvalue or -f=value
        flag = arg[1]
        attached = arg[2:]
        if attached:
            # A single '=' between the flag and its value is optional.
            values.append(attached.removeprefix('='))

    # Find the alias set that contains this flag and how many values it takes.
    aliases: list[str] | None = None
    value_count = 0
    for flag_set, count in flag_value_map.items():
        candidates = flag_set.split('|')
        if flag in candidates:
            aliases = candidates
            value_count = count
            break
    if aliases is None:
        err_arg_unrecognized(flag)
        return fail(ERR_UNRECOGNIZED_ARG)

    # Take the values still missing from the arguments that follow the flag.
    while len(values) < value_count:
        if i >= len(argv):
            print_error(f"Not enough arguments for flag '{flag}'")
            return fail(ERR_NOT_ENOUGH_ARGS)
        values.append(argv[i])
        i += 1

    # NOTE: we always return the first flag in the flag-set for normalization
    return aliases[0], values, i, None

def parse_generic(
    flag_value_map: dict[str, int],
    argv: list[str],
    i: int
) -> dict[str, list[str]] | None:
    res: dict[str, list[str]] = {}
    while i < len(argv):
        arg = argv[i]
        if len(arg) >= 2:
            flag, values, i, err = parse_arg(argv, i, flag_value_map)
            if err is not None:
                return None

            res[flag] = values
            #if not callback(flag, values):
            #    return False
        else:
            err_arg_unrecognized(arg)
            return None

    return res