(t *testing.T)
| 141 | } |
| 142 | |
| 143 | func TestLBClientTransportAccessorsAndOverrides(t *testing.T) { |
| 144 | t.Parallel() |
| 145 | |
| 146 | hostWithoutOverrides := &fasthttp.HostClient{Addr: "example.com:80"} |
| 147 | nestedDialHost := &fasthttp.HostClient{Addr: "example.org:80"} |
| 148 | nestedTLSHost := &fasthttp.HostClient{Addr: "example.net:80", TLSConfig: &tls.Config{ServerName: "example", MinVersion: tls.VersionTLS12}} |
| 149 | multiLevelHost := &fasthttp.HostClient{Addr: "example.edu:80"} |
| 150 | |
| 151 | nestedDialHost.Dial = func(addr string) (net.Conn, error) { |
| 152 | _ = addr |
| 153 | return nil, errors.New("original dial") |
| 154 | } |
| 155 | |
| 156 | multiLevelHost.Dial = func(addr string) (net.Conn, error) { |
| 157 | _ = addr |
| 158 | return nil, errors.New("multi-level dial") |
| 159 | } |
| 160 | |
| 161 | nestedDialLB := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{nestedDialHost}}} |
| 162 | nestedTLSLB := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{nestedTLSHost}}} |
| 163 | multiLevelLeaf := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{multiLevelHost}}} |
| 164 | multiLevelWrapper := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{multiLevelLeaf}}} |
| 165 | |
| 166 | lb := &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{ |
| 167 | stubBalancingClient{}, |
| 168 | hostWithoutOverrides, |
| 169 | nestedDialLB, |
| 170 | nestedTLSLB, |
| 171 | multiLevelWrapper, |
| 172 | }} |
| 173 | |
| 174 | transport := newLBClientTransport(lb) |
| 175 | require.Same(t, lb, transport.Client()) |
| 176 | cfg := transport.TLSConfig() |
| 177 | require.Same(t, nestedTLSHost.TLSConfig, cfg) |
| 178 | |
| 179 | overrideTLS := &tls.Config{ServerName: "override", MinVersion: tls.VersionTLS12} |
| 180 | transport.SetTLSConfig(overrideTLS) |
| 181 | require.Equal(t, overrideTLS, hostWithoutOverrides.TLSConfig) |
| 182 | require.Equal(t, overrideTLS, nestedDialHost.TLSConfig) |
| 183 | require.Equal(t, overrideTLS, nestedTLSHost.TLSConfig) |
| 184 | require.Equal(t, overrideTLS, multiLevelHost.TLSConfig) |
| 185 | cfg = transport.TLSConfig() |
| 186 | require.Same(t, overrideTLS, cfg) |
| 187 | cfg.ServerName = "mutated" |
| 188 | require.Equal(t, "mutated", transport.TLSConfig().ServerName) |
| 189 | |
| 190 | overrideDialCalled := atomic.Bool{} |
| 191 | overrideDial := func(addr string) (net.Conn, error) { |
| 192 | _ = addr |
| 193 | overrideDialCalled.Store(true) |
| 194 | return nil, errors.New("override dial") |
| 195 | } |
| 196 | transport.SetDial(overrideDial) |
| 197 | overrideDialCalled.Store(false) |
| 198 | _, err := hostWithoutOverrides.Dial("example.com:80") |
| 199 | require.Error(t, err) |
| 200 | require.True(t, overrideDialCalled.Load()) |
nothing calls this directly
no test coverage detected