Skip to content

Commit c6a5a5f

Browse files
committed
fix descriptor recursion
1 parent 892f9b3 commit c6a5a5f

File tree

9 files changed

+372
-227
lines changed

9 files changed

+372
-227
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ priv/plts
3131
# Ignore cursor and memory-bank
3232
.cursor/
3333
memory-bank/
34+
custom_modes/

lib/grpc_reflection/service/builder.ex

Lines changed: 200 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -2,149 +2,243 @@ defmodule GrpcReflection.Service.Builder do
22
@moduledoc false
33

44
alias Google.Protobuf.FileDescriptorProto
5-
alias GrpcReflection.Service.State
6-
alias GrpcReflection.Service.Builder.Util
5+
alias GrpcReflection.Service.Builder.Acc
76
alias GrpcReflection.Service.Builder.Extensions
7+
alias GrpcReflection.Service.Builder.Util
8+
alias GrpcReflection.Service.State
89

10+
@spec build_reflection_tree(any()) ::
11+
{:error, <<_::216>>} | {:ok, GrpcReflection.Service.State.t()}
912
def build_reflection_tree(services) do
1013
with :ok <- Util.validate_services(services) do
11-
tree =
12-
Enum.reduce(services, State.new(services), fn service, state ->
13-
new_state = process_service(service)
14-
State.merge(state, new_state)
14+
acc =
15+
Enum.reduce(services, %Acc{services: services}, fn service, acc ->
16+
process_service(acc, service)
1517
end)
1618

17-
{:ok, tree}
19+
{:ok, finalize(acc)}
1820
end
1921
end
2022

21-
defp process_service(service) do
22-
service_name = service.__meta__(:name)
23-
service_response = build_response(service_name, service)
23+
defp finalize(%Acc{} = acc) do
24+
files = build_files(acc)
25+
26+
symbols =
27+
acc.symbol_info
28+
|> Enum.reduce(%{}, fn {symbol, info}, map ->
29+
Map.put(map, symbol, Map.fetch!(files, info.file))
30+
end)
31+
|> then(fn base ->
32+
Enum.reduce(acc.aliases, base, fn {alias_symbol, target_symbol}, map ->
33+
payload = Map.fetch!(map, target_symbol)
34+
Map.put(map, alias_symbol, payload)
35+
end)
36+
end)
37+
38+
State.new(acc.services)
39+
|> State.add_files(Map.merge(files, acc.extension_files))
40+
|> State.add_symbols(symbols)
41+
|> State.add_extensions(acc.extensions)
42+
end
43+
44+
defp build_files(%Acc{} = acc) do
45+
acc.symbol_info
46+
|> Enum.group_by(fn {_symbol, info} -> info.file end)
47+
|> Enum.reduce(%{}, fn {file_name, entries}, files ->
48+
descriptors = Enum.map(entries, fn {_symbol, info} -> info end)
49+
50+
syntax = descriptors |> List.first() |> Map.fetch!(:syntax)
51+
package = descriptors |> List.first() |> Map.fetch!(:package)
52+
53+
dependencies =
54+
descriptors
55+
|> Enum.flat_map(& &1.deps)
56+
|> Enum.map(&resolve_dependency_file(&1, acc))
57+
|> Enum.reject(&is_nil/1)
58+
|> Enum.reject(&(&1 == file_name))
59+
|> Enum.uniq()
60+
61+
{messages, enums, services} =
62+
Enum.reduce(descriptors, {[], [], []}, fn %{descriptor: descriptor},
63+
{messages, enums, services} ->
64+
case descriptor do
65+
%Google.Protobuf.DescriptorProto{} ->
66+
{[descriptor | messages], enums, services}
67+
68+
%Google.Protobuf.EnumDescriptorProto{} ->
69+
{messages, [descriptor | enums], services}
70+
71+
%Google.Protobuf.ServiceDescriptorProto{} ->
72+
{messages, enums, [descriptor | services]}
73+
end
74+
end)
75+
|> then(fn {messages, enums, services} ->
76+
{Enum.reverse(messages), Enum.reverse(enums), Enum.reverse(services)}
77+
end)
78+
79+
file_proto = %FileDescriptorProto{
80+
name: file_name,
81+
package: package,
82+
dependency: dependencies,
83+
syntax: syntax,
84+
message_type: messages,
85+
enum_type: enums,
86+
service: services
87+
}
2488

25-
State.new()
26-
|> State.add_symbols(%{service_name => service_response})
27-
|> State.add_files(%{(service_name <> ".proto") => service_response})
28-
|> trace_service_refs(service)
89+
Map.put(files, file_name, %{file_descriptor_proto: [FileDescriptorProto.encode(file_proto)]})
90+
end)
2991
end
3092

31-
defp trace_service_refs(state, module) do
32-
service_name = module.__meta__(:name)
33-
methods = get_descriptor(module).method
93+
defp process_service(%Acc{} = acc, service) do
94+
service_name = service.__meta__(:name)
95+
{acc, _} = register_symbol(acc, service_name, service, :service)
96+
97+
methods = get_descriptor(service).method
3498

35-
module.__rpc_calls__()
36-
|> Enum.reduce(state, fn call, state ->
37-
{function, {request, _}, {response, _}, _} = call
99+
Enum.reduce(service.__rpc_calls__(), acc, fn call, acc ->
100+
{function, {req, _}, {resp, _}, _} = call
38101

39102
%{input_type: req_symbol, output_type: resp_symbol} =
40103
Enum.find(methods, fn method -> method.name == to_string(function) end)
41104

42-
call_symbol = service_name <> "." <> to_string(function)
43-
call_response = build_response(service_name, module)
44105
req_symbol = Util.trim_symbol(req_symbol)
45-
req_response = build_response(req_symbol, request)
46106
resp_symbol = Util.trim_symbol(resp_symbol)
47-
resp_response = build_response(resp_symbol, response)
48-
49-
state
50-
|> Extensions.add_extensions(service_name, module)
51-
|> State.add_symbols(%{
52-
call_symbol => call_response,
53-
req_symbol => req_response,
54-
resp_symbol => resp_response
55-
})
56-
|> State.add_files(%{
57-
(req_symbol <> ".proto") => req_response,
58-
(resp_symbol <> ".proto") => resp_response
59-
})
60-
|> Extensions.add_extensions(req_symbol, request)
61-
|> Extensions.add_extensions(resp_symbol, response)
62-
|> trace_message_refs(req_symbol, request)
63-
|> trace_message_refs(resp_symbol, response)
107+
108+
method_symbol = service_name <> "." <> to_string(function)
109+
110+
acc
111+
|> register_alias(method_symbol, service_name)
112+
|> Extensions.add_extensions(service_name, service)
113+
|> process_message(req_symbol, req)
114+
|> process_message(resp_symbol, resp)
64115
end)
65116
end
66117

67-
defp trace_message_refs(state, parent_symbol, module) do
68-
case module.descriptor() do
69-
%{field: fields} ->
70-
trace_message_fields(state, parent_symbol, module, fields)
118+
defp process_message(%Acc{} = acc, nil, _module, _root_symbol), do: acc
71119

72-
_ ->
73-
state
74-
end
75-
end
120+
defp process_message(%Acc{} = acc, symbol, module, root_symbol) do
121+
symbol = Util.trim_symbol(symbol)
122+
root_symbol = root_symbol || symbol
76123

77-
defp trace_message_fields(state, parent_symbol, module, fields) do
78-
# nested types arent a "separate file", they return their parents' response
79-
nested_types = Util.get_nested_types(parent_symbol, module.descriptor())
80-
81-
module.__message_props__().field_props
82-
|> Map.values()
83-
|> Enum.map(fn %{name: name, type: type} ->
84-
%{
85-
mod:
86-
case type do
87-
{_, mod} -> mod
88-
mod -> mod
89-
end,
90-
symbol: Enum.find(fields, fn f -> f.name == name end).type_name
91-
}
92-
end)
93-
|> Enum.reject(fn %{symbol: s} -> s == nil end)
94-
|> Enum.reduce(state, fn %{mod: mod, symbol: symbol}, state ->
95-
symbol = Util.trim_symbol(symbol)
124+
{acc, already_processed} =
125+
if root_symbol == symbol do
126+
register_symbol(acc, symbol, module, :message)
127+
else
128+
acc = register_alias(acc, symbol, root_symbol)
96129

97-
response =
98-
if symbol in nested_types do
99-
build_response(parent_symbol, module)
130+
if MapSet.member?(acc.visited, symbol) do
131+
{acc, true}
100132
else
101-
build_response(symbol, mod)
133+
{acc, false}
102134
end
135+
end
103136

104-
state
105-
|> Extensions.add_extensions(symbol, mod)
106-
|> State.add_symbols(%{symbol => response})
107-
|> State.add_files(%{(symbol <> ".proto") => response})
108-
|> trace_message_refs(symbol, mod)
109-
end)
137+
if already_processed do
138+
acc
139+
else
140+
acc = %{acc | visited: MapSet.put(acc.visited, symbol)}
141+
acc = Extensions.add_extensions(acc, symbol, module)
142+
143+
case module.descriptor() do
144+
%{field: fields} = descriptor ->
145+
nested_symbols = Util.get_nested_types(symbol, descriptor)
146+
147+
module.__message_props__().field_props
148+
|> Map.values()
149+
|> Enum.map(fn %{name: name, type: type} ->
150+
%{
151+
mod:
152+
case type do
153+
{_, mod} -> mod
154+
mod -> mod
155+
end,
156+
symbol: Enum.find(fields, fn f -> f.name == name end).type_name
157+
}
158+
end)
159+
|> Enum.reject(fn %{symbol: s} -> is_nil(s) end)
160+
|> Enum.reduce(acc, fn %{mod: mod, symbol: child_symbol}, acc ->
161+
child_symbol = Util.trim_symbol(child_symbol)
162+
163+
if child_symbol in nested_symbols do
164+
process_message(acc, child_symbol, mod, root_symbol)
165+
else
166+
process_message(acc, child_symbol, mod)
167+
end
168+
end)
169+
170+
_ ->
171+
acc
172+
end
173+
end
110174
end
111175

112-
defp build_response(symbol, module) do
113-
# we build our own file responses, so unwrap any present
114-
descriptor = get_descriptor(module)
115-
116-
dependencies =
117-
descriptor
118-
|> Util.types_from_descriptor()
119-
|> Enum.uniq()
120-
|> Kernel.--(Util.get_nested_types(symbol, descriptor))
121-
|> Enum.map(fn name ->
122-
Util.trim_symbol(name) <> ".proto"
123-
end)
124-
125-
syntax = Util.get_syntax(module)
176+
defp process_message(%Acc{} = acc, symbol, module) do
177+
process_message(acc, symbol, module, nil)
178+
end
126179

127-
response_stub =
128-
%FileDescriptorProto{
129-
name: symbol <> ".proto",
180+
defp register_symbol(%Acc{} = acc, symbol, module, kind) do
181+
symbol = Util.trim_symbol(symbol)
182+
183+
if Map.has_key?(acc.symbol_info, symbol) do
184+
{acc, true}
185+
else
186+
descriptor = get_descriptor(module)
187+
188+
info = %{
189+
descriptor: descriptor,
190+
deps:
191+
descriptor
192+
|> Util.types_from_descriptor()
193+
|> Enum.map(&Util.trim_symbol/1)
194+
|> Enum.uniq(),
195+
file: Util.proto_filename(module),
196+
syntax: Util.get_syntax(module),
130197
package: Util.get_package(symbol),
131-
dependency: dependencies,
132-
syntax: syntax
198+
kind: kind
133199
}
134200

135-
unencoded_payload =
136-
case descriptor = descriptor do
137-
%Google.Protobuf.DescriptorProto{} -> %{response_stub | message_type: [descriptor]}
138-
%Google.Protobuf.ServiceDescriptorProto{} -> %{response_stub | service: [descriptor]}
139-
%Google.Protobuf.EnumDescriptorProto{} -> %{response_stub | enum_type: [descriptor]}
140-
end
201+
{%{
202+
acc
203+
| symbol_info: Map.put(acc.symbol_info, symbol, info),
204+
visited: MapSet.put(acc.visited, symbol)
205+
}, false}
206+
end
207+
end
208+
209+
defp register_alias(%Acc{} = acc, alias_symbol, target_symbol) do
210+
alias_symbol = Util.trim_symbol(alias_symbol)
211+
target_symbol = Util.trim_symbol(target_symbol)
212+
213+
cond do
214+
alias_symbol == target_symbol -> acc
215+
Map.get(acc.aliases, alias_symbol) == target_symbol -> acc
216+
true -> %{acc | aliases: Map.put(acc.aliases, alias_symbol, target_symbol)}
217+
end
218+
end
141219

142-
%{file_descriptor_proto: [FileDescriptorProto.encode(unencoded_payload)]}
220+
defp resolve_dependency_file(nil, _acc), do: nil
221+
222+
defp resolve_dependency_file(symbol, %Acc{} = acc) do
223+
symbol = Util.trim_symbol(symbol)
224+
225+
cond do
226+
info = Map.get(acc.symbol_info, symbol) ->
227+
info.file
228+
229+
target = Map.get(acc.aliases, symbol) ->
230+
acc.symbol_info
231+
|> Map.get(target)
232+
|> case do
233+
nil -> symbol <> ".proto"
234+
info -> info.file
235+
end
236+
237+
true ->
238+
symbol <> ".proto"
239+
end
143240
end
144241

145-
# protoc with the elixir generator and protobuf.generate slightly differ for how they
146-
# generate descriptors. Use this to potentially unwrap the service proto when dealing
147-
# with descriptors that could come from a service module.
148242
defp get_descriptor(module) do
149243
case module.descriptor() do
150244
%FileDescriptorProto{service: [proto]} -> proto
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
defmodule GrpcReflection.Service.Builder.Acc do
2+
@moduledoc false
3+
4+
defstruct services: [],
5+
symbol_info: %{},
6+
aliases: %{},
7+
visited: MapSet.new(),
8+
extension_files: %{},
9+
extensions: %{}
10+
end

0 commit comments

Comments
 (0)