aboutsummaryrefslogtreecommitdiffstats
path: root/pish_derive/src/lib.rs
blob: 3a296fcffa770bf4956d7dd98acd7d2b2224c219 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use proc_macro::TokenStream;
use quote::{ToTokens, quote};
use syn::{Data, DeriveInput, Fields, parse_macro_input};

#[proc_macro_derive(FromArgs)]
pub fn derive_cli(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = input.ident;

    let fields = match input.data {
        Data::Struct(data) => data.fields,
        _ => panic!("Cli can only be derived for structs"),
    };

    let mut field_parsers_init = Vec::new();
    let mut field_parsers = Vec::new();
    let mut field_parsers_post = Vec::new();
    let mut field_names = Vec::new();

    if let Fields::Named(fields_named) = fields {
        for field in fields_named.named {
            let ident = field.ident.unwrap();
            let name_str = ident.to_string();

            let mut long_name = b"--".to_vec();
            long_name.extend_from_slice(name_str.as_bytes());
            let long_name = proc_macro2::Literal::byte_string(&long_name);

            field_names.push(ident.clone());

            // initialization
            field_parsers_init.push(quote! {
                let mut #ident = None;
            });

            let is_bool = field.ty.to_token_stream().to_string() == String::from("bool");
            let is_option = field.ty.to_token_stream().to_string().starts_with("Option"); // bad bad detection

            // in the loop
            if is_bool {
                field_parsers.push(quote! {
                    if arg == #long_name {
                        #ident = Some(true);
                        continue;
                    }
                });
            } else {
                field_parsers.push(quote! {
                    if arg == #long_name {
                        let Some(val) = iter.next() else {
                            return Err(ArgParseError::MissingArgValue(#name_str));
                        };
                        match String::from_utf8_lossy(val).parse() {
                            Ok(parsed) => {
                                #ident = Some(parsed);
                                continue;
                            }
                            Err(err) => {
                                return Err(ArgParseError::ArgValueParseError(#name_str, format!("{err:?}")));
                            }
                        }
                    }
                });
            }

            // after loop
            if is_bool {
                field_parsers_post.push(quote! { let #ident = #ident.unwrap_or(false); });
            } else if !is_option {
                field_parsers_post.push(quote!{ let Some(#ident) = #ident else { return Err(ArgParseError::MissingArg(#name_str)) }; });
            }
        }
    }

    let expanded = quote! {
        impl ArgParse for #name {
            fn parse<'a>(args: &'a [BString]) -> std::result::Result<Self, ArgParseError<'a>> {
                let mut iter = args.iter();

                #(#field_parsers_init)*

                while let Some(arg) = iter.next() {
                    #(#field_parsers)*;
                    return Err(ArgParseError::LeftoverArg(arg));
                }

                #(#field_parsers_post)*

                Ok(Self { #( #field_names ),* })
            }
        }
    };

    TokenStream::from(expanded)
}