protoc-gen-custom

write your own custom protobuf code generator

SEAN K.H. LIAO

protoc-gen-custom

I was thinking about a more declarative way to do flags in Go, maybe something like RegisterFlags(myStruct), which would look at struct tags for all the things we need, like name overrides, help text, and default values. But I didn't like that struct tags are unstructured strings, and it didn't sound like we'd get structured tags anytime soon.

So my mind wandered over to protobuf. It has structured extension options:

message MyMsg {
  option (my.thing).foo = 1;
  option (my.thing).bar = "two";

  string field1 = 1 [
    (my.thing).foo = 3,
    (my.thing).bar = "four"
  ];
}

custom extensions

Let's start by defining some custom extensions. We're adding options for messages and for fields in messages. I think it's nicer to have a single extension with a message as its type, rather than multiple extensions, as it makes for fewer checks in the generator code.

syntax = "proto3";

package com.example.me;
option go_package = "me.example.com/goflag";

import "google/protobuf/duration.proto";
import "google/protobuf/descriptor.proto";

message GroupConfig {
  string prefix = 1;
  FlagCase case = 2;
}

enum FlagCase {
  SNAKE = 0;
  KEBAB = 1;
  DOT = 2;
}

extend google.protobuf.MessageOptions {
  optional GroupConfig flags = 59401;
}

message FlagConfig {
  optional string name = 1;
  string help = 2;
  // Default value for the flag. The field names match the GetVal* getters
  // the generator calls; the field numbers here are illustrative.
  oneof val {
    bool val_bool = 3;
    int64 val_int64 = 4;
    double val_float = 5;
    string val_string = 6;
    google.protobuf.Duration val_dur = 7;
  }
}

extend google.protobuf.FieldOptions {
  optional FlagConfig flag = 59501;
}

code generator

Now to write the code generator. protoc looks at the flags it's given, like --foo_out, and looks for an executable named protoc-gen-foo in $PATH. The executable receives a CodeGeneratorRequest message on stdin describing the parsed proto files, and writes a CodeGeneratorResponse with the generated files to stdout.
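The naming half of that convention fits in a few lines. This is a hypothetical helper, not protoc's actual code: it maps an --NAME_out flag to a protoc-gen-NAME binary name and then looks it up with exec.LookPath. It skips the stdin/stdout message exchange entirely, and it ignores languages protoc has built in (such as --cpp_out), which never look for an external plugin.

```go
package main

import (
	"fmt"
	"os/exec"
	"strings"
)

// pluginBinary is a hypothetical helper spelling out the naming rule:
// an --NAME_out flag is served by an executable named protoc-gen-NAME.
func pluginBinary(arg string) (string, bool) {
	if !strings.HasPrefix(arg, "--") {
		return "", false
	}
	opt := strings.SplitN(arg[2:], "=", 2)[0] // drop "=OUTDIR"
	if !strings.HasSuffix(opt, "_out") || opt == "_out" {
		return "", false
	}
	return "protoc-gen-" + strings.TrimSuffix(opt, "_out"), true
}

func main() {
	bin, ok := pluginBinary("--go-flag_out=.")
	fmt.Println(bin, ok) // protoc-gen-go-flag true

	// The binary is then found like any other executable on $PATH.
	if path, err := exec.LookPath(bin); err == nil {
		fmt.Println("found:", path)
	} else {
		fmt.Println("not found:", err)
	}
}
```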

We don't need to care about all that; the protogen package handles it for us. We just need to write a callback.

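For the setup, we loop over the input files and only run our generator on the ones we're told to generate code for. p.Files also contains every file those import (such as goflag.proto and descriptor.proto), and only the files named on the protoc command line have Generate set.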

package main

import (
        "fmt"
        "io"
        "strconv"
        "strings"

        "me.example.com/goflag"
        "google.golang.org/protobuf/compiler/protogen"
        "google.golang.org/protobuf/proto"
        "google.golang.org/protobuf/reflect/protoreflect"
)

func main() {
        opt := protogen.Options{}
        opt.Run(func(p *protogen.Plugin) error {
                for _, file := range p.Files {
                        if !file.Generate {
                                continue
                        }

                        err := generateFile(p, file)
                        if err != nil {
                                return fmt.Errorf("%s: %w", file.Desc.Path(), err)
                        }
                }
                return nil
        })
}

For the actual generation, we're given a filename prefix to which we attach our own suffix (.flag.go here). The Go import path is calculated for us automatically, and the final output is run through the Go formatter for us too.

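Go has a convention of marking generated files with a comment matching ^// Code generated .* DO NOT EDIT\.$, placed before the first non-comment, non-blank text in the file, so that tooling knows to treat them as generated and leave them alone.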
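A minimal sketch of how a tool might detect that marker, using the regular expression above. isGenerated is a hypothetical helper; the "before the first non-comment, non-blank text" rule comes from `go help generate`, but this version only treats // line comments as comments, so a /* */ block ends the search.

```go
package main

import (
	"bufio"
	"fmt"
	"regexp"
	"strings"
)

// generatedRe is the marker pattern from the Go convention.
var generatedRe = regexp.MustCompile(`^// Code generated .* DO NOT EDIT\.$`)

// isGenerated reports whether src carries the marker before its first
// non-comment, non-blank line (only // comments are recognized).
func isGenerated(src string) bool {
	sc := bufio.NewScanner(strings.NewReader(src))
	for sc.Scan() {
		line := sc.Text()
		if generatedRe.MatchString(line) {
			return true
		}
		trimmed := strings.TrimSpace(line)
		if trimmed != "" && !strings.HasPrefix(trimmed, "//") {
			return false // reached real code, e.g. the package clause
		}
	}
	return false
}

func main() {
	header := "// Code generated by protoc-gen-go-flag. DO NOT EDIT.\n\npackage test\n"
	fmt.Println(isGenerated(header))                     // true
	fmt.Println(isGenerated("package test\n\n" + header)) // false: marker comes after code
}
```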

Looping over the messages is easy; figuring out how to get the options out of them, less so. I eventually figured out (from peeking at the vtprotobuf code) that it's proto.GetExtension(msg.Desc.Options(), our_generated_proto.E_extension_name), which returns the option value as an interface to type-assert. With that, the rest of the generation was fairly easy.

func generateFile(p *protogen.Plugin, file *protogen.File) error {
        fn := file.GeneratedFilenamePrefix + ".flag.go"

        f := p.NewGeneratedFile(fn, file.GoImportPath)

        io.WriteString(f, "// Code generated by protoc-gen-go-flag. DO NOT EDIT.\n\n")
        io.WriteString(f, "package "+string(file.GoPackageName)+"\n")
        io.WriteString(f, `import "flag"`)
        io.WriteString(f, "\n")

        var didGen bool
        for _, msg := range file.Messages {
                if !proto.HasExtension(msg.Desc.Options(), goflag.E_Flags) {
                        continue
                }
                didGen = true

                groupConf := proto.GetExtension(msg.Desc.Options(), goflag.E_Flags).(*goflag.GroupConfig)

                fmt.Fprintf(f, "func (x *%s) RegisterFlags(fset *flag.FlagSet) {\n", msg.GoIdent.GoName)
                for _, field := range msg.Fields {
                        // flag can only bind plain singular fields: skip repeated
                        // fields, maps, and oneof members (including proto3
                        // optional fields, which become pointers in Go).
                        if field.Desc.IsList() || field.Desc.IsMap() || field.Desc.ContainingOneof() != nil {
                                continue
                        }

                        // fieldConf stays nil without a (flag) option; the
                        // generated getters are nil-safe and return zero values.
                        var fieldConf *goflag.FlagConfig
                        if proto.HasExtension(field.Desc.Options(), goflag.E_Flag) {
                                fieldConf = proto.GetExtension(field.Desc.Options(), goflag.E_Flag).(*goflag.FlagConfig)
                        }

                        flagName := field.Desc.TextName()
                        if n := fieldConf.GetName(); n != "" {
                                flagName = n
                        }
                        if groupConf.Prefix != "" {
                                flagName = groupConf.Prefix + "_" + flagName
                        }
                        switch groupConf.Case {
                        case goflag.FlagCase_SNAKE:
                                // noop
                        case goflag.FlagCase_KEBAB:
                                flagName = strings.ReplaceAll(flagName, "_", "-")
                        case goflag.FlagCase_DOT:
                                flagName = strings.ReplaceAll(flagName, "_", ".")
                        }
                        flagName = strconv.Quote(flagName)

                        help := strconv.Quote(fieldConf.GetHelp())

                        var varType, defaultValue string
                        switch field.Desc.Kind() {
                        case protoreflect.BoolKind:
                                defaultValue = strconv.FormatBool(fieldConf.GetValBool())
                                varType = "Bool"
                        case protoreflect.Int64Kind,
                                protoreflect.Sint64Kind,
                                protoreflect.Sfixed64Kind:
                                // Only the signed 64-bit kinds are int64 in Go. The 32-bit
                                // and unsigned kinds (fixed32/fixed64 included) are other
                                // Go types Int64Var can't bind, so they fall to default.
                                defaultValue = strconv.FormatInt(fieldConf.GetValInt64(), 10)
                                varType = "Int64"
                        case protoreflect.DoubleKind:
                                // float is float32 in Go, so only double maps to Float64Var.
                                // 'g' with precision -1 is the shortest literal that round-trips.
                                defaultValue = strconv.FormatFloat(fieldConf.GetValFloat(), 'g', -1, 64)
                                varType = "Float64"
                        case protoreflect.StringKind:
                                defaultValue = strconv.Quote(fieldConf.GetValString())
                                varType = "String"
                        case protoreflect.MessageKind:
                                // Check the field's message type, not the field's own name.
                                if field.Desc.Message().FullName() != "google.protobuf.Duration" {
                                        continue
                                }
                                // The Go field is a *durationpb.Duration, which DurationVar
                                // can't bind: set the default directly and parse with fset.Func.
                                newDur := f.QualifiedGoIdent(protogen.GoImportPath("google.golang.org/protobuf/types/known/durationpb").Ident("New"))
                                parseDur := f.QualifiedGoIdent(protogen.GoImportPath("time").Ident("ParseDuration"))
                                dur := fieldConf.GetValDur().AsDuration().Nanoseconds()
                                fmt.Fprintf(f, "x.%s = %s(%d)\n", field.GoName, newDur, dur)
                                fmt.Fprintf(f, "fset.Func(%s, %s, func(s string) error {\nd, err := %s(s)\nif err != nil {\nreturn err\n}\nx.%s = %s(d)\nreturn nil\n})\n",
                                        flagName, help, parseDur, field.GoName, newDur)
                                continue
                        default:
                                // Enums, bytes, groups, float, and 32-bit/unsigned integers.
                                continue
                        }

                        fmt.Fprintf(f, "fset.%sVar(&x.%s, %s, %s, %s)\n", varType, field.GoName, flagName, defaultValue, help)
                }
                fmt.Fprintf(f, "}\n")
        }

        if !didGen {
                f.Skip()
        }

        return nil
}

testing

It took a little trial and error to figure out, but the fully qualified extension name needs to be wrapped in parentheses, while the message field names go outside them. The value can be set either one field at a time, or as a whole message in prototext (text format) representation. Annoyingly, referring to the enum value by name didn't seem to work here(?), so case is given by number (1 is KEBAB), even though the text format normally accepts enum value names.

syntax = "proto3";

package com.example.me.test;
option go_package = "me.example.com/test";

import "goflag/goflag.proto";

message Config {
  option (com.example.me.flags) = { prefix: "myapp" case: 1 }; // 1 = KEBAB
  string http_host = 1 [
    (com.example.me.flag).help = "the http host",
    (com.example.me.flag).val_string = "def"
  ];
  string foo = 2 [
    (com.example.me.flag) = {
      name: "bar"
      help: "some renamed flag"
    }
  ];
}